diff --git a/.gitmodules b/.gitmodules
index 9aeb1c754983..836d824a6f5a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -26,3 +26,6 @@
 [submodule "3rdparty/tvm"]
   path = 3rdparty/tvm
   url = https://github.com/dmlc/tvm
+[submodule "3rdparty/onnx-tensorrt"]
+  path = 3rdparty/onnx-tensorrt
+  url = https://github.com/onnx/onnx-tensorrt.git
diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt
new file mode 160000
index 000000000000..e7be19cff377
--- /dev/null
+++ b/3rdparty/onnx-tensorrt
@@ -0,0 +1 @@
+Subproject commit e7be19cff377a95817503e8525e20de34cdc574a
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 483108a68419..8ff337ed159a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support"
 mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON)
 mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
 mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF)
+mxnet_option(USE_TENSORRT "Enable inference optimization with TensorRT." OFF)
 
 message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
 if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,36 @@ if(USE_VTUNE)
   list(APPEND mxnet_LINKER_LIBS dl)
 endif()
 
+if(USE_TENSORRT)
+  message(STATUS "Using TensorRT")
+  set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/)
+  set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/)
+
+  include_directories(${ONNX_PATH})
+  include_directories(3rdparty/onnx-tensorrt/)
+  include_directories(3rdparty/)
+  add_definitions(-DMXNET_USE_TENSORRT=1)
+  add_definitions(-DONNX_NAMESPACE=onnx)
+
+  find_package(Protobuf REQUIRED)
+
+  find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED
+               PATHS ${ONNX_PATH}
+               DOC "Path to onnx library.")
+  find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED
+               PATHS ${ONNX_PATH}
+               DOC "Path to onnx_proto library.")
+  find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED
+               PATHS ${ONNX_TRT_PATH}
+               DOC "Path to nvonnxparser_runtime library.")
+  find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED
+               PATHS ${ONNX_TRT_PATH}
+               DOC "Path to nvonnxparser library.")
+
+  list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY}
+       ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY})
+endif()
+
 if(USE_MKLDNN)
   include(cmake/MklDnn.cmake)
   # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
diff --git a/Jenkinsfile b/Jenkinsfile index 6d21f496426e..758e8e870eee 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -28,6 +28,7 @@ mx_dist_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3r mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' // timeout in minutes max_time = 120 // assign any caught errors here @@ -372,6 +373,17 @@ try { } } }, + 'TensorRT': { + node(NODE_LINUX_CPU) { + ws('workspace/build-tensorrt') { + timeout(time: max_time, unit: 'MINUTES') { + utils.init_git() + utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false) + utils.pack_lib('tensorrt', mx_tensorrt_lib) + } + } + } + }, 'Build CPU windows':{ node('mxnetwindows-cpu') { timeout(time: max_time, unit: 'MINUTES') { @@ -740,6 +752,22 @@ try { } } }, + 'Python3: TensorRT GPU': { + node(NODE_LINUX_GPU_P3) { + ws('workspace/build-tensorrt') { + timeout(time: max_time, unit: 'MINUTES') { + try { + utils.init_git() + utils.unpack_lib('tensorrt', mx_tensorrt_lib) + utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) + utils.publish_test_coverage() + } finally { + utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml') + } + } + } + } + }, 'Scala: CPU': { node('mxnetlinux-cpu') { ws('workspace/ut-scala-cpu') { diff --git a/Makefile b/Makefile index 88f7dd9278cb..b794e00f00aa 100644 --- a/Makefile +++ b/Makefile @@ -91,6 +91,14 @@ else endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) + + +ifeq ($(USE_TENSORRT), 1) + CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1 + LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin +endif +# -L/usr/local/lib + ifeq ($(DEBUG), 1) NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) else diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py index 52d775b76925..a3c28f7118e9 100644 --- a/amalgamation/amalgamation.py +++ b/amalgamation/amalgamation.py @@ -23,13 +23,12 @@ import platform blacklist = [ - 'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', - 'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', - 'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h', - 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h', - 'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', - 'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', - 'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h', + 'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h', + 
 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
+    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
+    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
+    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
+    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
     'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
     'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
     ]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
                     h not in sysheaders and
                     'mkl' not in h and
                     'nnpack' not in h and
+                    'tensorrt' not in h and
                     not h.endswith('.cuh')):
                 sysheaders.append(h)
             else:
                 expand.treeDepth += 1
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
new file mode 100755
index 000000000000..255da316041f
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,41 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU with TensorRT
+
+FROM nvidia/cuda:9.0-cudnn7-devel
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+COPY install/tensorrt.sh /work
+RUN /work/tensorrt.sh
+
+ARG USER_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh
new file mode 100755
index 000000000000..a6258d94f62f
--- /dev/null
+++ b/ci/docker/install/tensorrt.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +# Install gluoncv since we're testing Gluon models as well +pip2 install gluoncv==0.2.0 +pip3 install gluoncv==0.2.0 + +# Install Protobuf +# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt) +pushd . +cd .. +apt-get update +apt-get install -y automake libtool +git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git +cd protobuf +./autogen.sh +./configure +make -j$(nproc) +make install +ldconfig +popd + +# Install TensorRT +echo "TensorRT build enabled. Installing TensorRT." +wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb +dpkg -i tensorrt.deb +apt-get update +apt-get install -y --allow-downgrades libnvinfer-dev +rm tensorrt.deb diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index a0795eb58a5a..3e19eaf70049 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -436,6 +436,60 @@ build_ubuntu_gpu() { build_ubuntu_gpu_cuda91_cudnn7 } +build_ubuntu_gpu_tensorrt() { + + set -ex + + build_ccache_wrappers + + # Build ONNX + pushd . + echo "Installing ONNX." + cd 3rdparty/onnx-tensorrt/third_party/onnx + rm -rf build + mkdir -p build + cd build + cmake \ + -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\ + -DBUILD_SHARED_LIBS=ON ..\ + -G Ninja + ninja -v + export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH + export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH + popd + + # Build ONNX-TensorRT + pushd . + cd 3rdparty/onnx-tensorrt/ + mkdir -p build + cd build + cmake .. + make -j$(nproc) + export LIBRARY_PATH=`pwd`:$LIBRARY_PATH + popd + + mkdir -p /work/mxnet/lib/ + cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/ + cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/ + cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/ + + rm -rf build + make \ + DEV=1 \ + USE_BLAS=openblas \ + USE_CUDA=1 \ + USE_CUDA_PATH=/usr/local/cuda \ + USE_CUDNN=1 \ + USE_OPENCV=0 \ + USE_DIST_KVSTORE=0 \ + USE_TENSORRT=1 \ + USE_JEMALLOC=0 \ + USE_GPERFTOOLS=0 \ + ONNX_NAMESPACE=onnx \ + CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\ + -j$(nproc) +} + build_ubuntu_gpu_mkldnn() { set -ex @@ -638,6 +692,15 @@ unittest_ubuntu_python3_gpu_nocudnn() { nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu } +unittest_ubuntu_tensorrt_gpu() { + set -ex + export PYTHONPATH=./python/ + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH + python tests/python/tensorrt/lenet5_train.py + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/ +} + # quantization gpu currently only runs on P3 instances # need to separte it from unittest_ubuntu_python2_gpu() unittest_ubuntu_python2_quantization_gpu() { @@ -970,3 +1033,5 @@ EOF declare -F | cut -d' ' -f3 echo fi + + diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 75147cfd706d..58b1b1b4dafe 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1714,6 +1714,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping, NDArrayHandle** aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); + +/*! + * \brief get optimized graph from graph executor + */ +MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, + SymbolHandle *out); + /*! 
 * \brief set a callback to notify the completion of operation
 */
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index 842653f86537..0ab04b86a0a1 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -166,6 +166,7 @@ class Executor {
                              std::unordered_map* shared_data_arrays = nullptr,
                              Executor* shared_exec = nullptr);
+
   /*!
    * \brief the prototype of user-defined monitor callback
    */
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 4df794bdfe37..5f4ae8bd0ac9 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -709,3 +709,19 @@ def write_all_str(module_file, module_all_list):
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+def cint(init_val=0):
+    """create a C int with an optional initial value"""
+    return C.c_int(init_val)
+
+def int_addr(x):
+    """given a c_int, return its address as an int ptr"""
+    x_addr = C.addressof(x)
+    int_p = C.POINTER(C.c_int)
+    x_int_addr = C.cast(x_addr, int_p)
+    return x_int_addr
+
+def checked_call(f, *args):
+    """call a cuda function and check for success"""
+    error_t = f(*args)
+    assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index fbfd3469678b..606bb0ada54f 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -32,3 +32,4 @@
 from . import io
 from . import quantization
 from . import quantization as quant
+from . import tensorrt
diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py
new file mode 100644
index 000000000000..bb20767b3e6f
--- /dev/null
+++ b/python/mxnet/contrib/tensorrt.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Module to enable the use of TensorRT optimized graphs."""
+
+import ctypes
+import logging
+import os
+
+from mxnet.symbol import Symbol
+
+from ..base import _LIB, SymbolHandle, MXNetError
+from ..base import check_call
+
+
+def set_use_tensorrt(status):
+    """
+    Set an environment variable which will enable or disable the use of TensorRT in the backend.
+    Note: this is useful for A/B testing purposes.
+    :param status: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+    """
+    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
+
+
+def get_use_tensorrt():
+    """
+    Get an environment variable which describes if TensorRT is currently enabled in the backend.
+    Note: this is useful for A/B testing purposes.
+    :return: Boolean, True if TensorRT optimization should be applied, False for legacy
+    behaviour.
+ """ + return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1) + + +def get_optimized_symbol(executor): + """ + Take an executor's underlying symbol graph and return its generated optimized version. + + Parameters + ---------- + executor : + An executor for which you want to see an optimized symbol. Getting an optimized symbol + is useful to compare and verify the work TensorRT has done against a legacy behaviour. + + Returns + ------- + symbol : nnvm::Symbol + The nnvm symbol optimized. + """ + handle = SymbolHandle() + try: + check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle))) + result = Symbol(handle=handle) + return result + except MXNetError: + logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure ' + 'build was compiled with MXNET_USE_TENSORRT enabled.') + raise + + +def tensorrt_bind(symbol, ctx, all_params, type_dict=None, stype_dict=None, group2ctx=None, + **kwargs): + """Bind current symbol to get an optimized trt executor. + + Parameters + ---------- + symbol : Symbol + The symbol you wish to bind, and optimize with TensorRT. + + ctx : Context + The device context the generated executor to run on. + + all_params : Dict of str->ndarray + A dictionary of mappings from parameter names to parameter NDArrays. + + type_dict : Dict of str->numpy.dtype + Input type dictionary, name->dtype + + stype_dict : Dict of str->str + Input storage type dictionary, name->storage_type + + group2ctx : Dict of string to mx.Context + The dict mapping the `ctx_group` attribute to the context assignment. + + kwargs : Dict of str->shape + Input shape dictionary, name->shape + + Returns + ------- + executor : mxnet.Executor + An optimized TensorRT executor. + """ + kwargs['shared_buffer'] = all_params + return symbol.simple_bind(ctx, type_dict=type_dict, stype_dict=stype_dict, + group2ctx=group2ctx, **kwargs) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index c0272c5bb433..fcd5406236e9 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -73,6 +73,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx): self.aux_arrays = [] self.outputs = self._get_outputs() self._symbol = copy.deepcopy(symbol) + self._optimized_symbol = None self._arg_dict = None self._grad_dict = None self._aux_dict = None diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index 5d8e95077c40..c4050699bd52 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -592,8 +592,8 @@ def backward(self, out_grads=None): # pylint: disable=no-member og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start, end=islice.stop) - # pylint: enable=no-member out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i])) + # pylint: enable=no-member else: out_grads_slice.append(grad.copyto(self.contexts[i])) exec_.backward(out_grads=out_grads_slice) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 09bc23934e5a..b99350525bfa 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -26,6 +26,10 @@ #include #include #include "./c_api_common.h" +#include "../executor/graph_executor.h" +#if MXNET_USE_TENSORRT +#include "../executor/trt_graph_executor.h" +#endif // MXNET_USE_TENSORRT int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); @@ -439,13 +443,38 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector in_arg_vec; std::vector 
arg_grad_vec; std::vector aux_state_vec; - - *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, - grad_req_type_vec, shared_arg_name_set, &in_arg_vec, - &arg_grad_vec, &aux_state_vec, - use_shared_buffer ? &shared_buffer_map : nullptr, - reinterpret_cast(shared_exec_handle)); +#if MXNET_USE_TENSORRT + // If we've built with TensorRT support we by default return an TRTExecutor. + // Users can override this behaviour via env var, which is useful for example for A/B + // performance testing. + if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, + &arg_grad_ctx_vec, &aux_state_ctx_vec, + &arg_shape_map, &arg_dtype_map, &arg_stype_map, + &grad_req_type_vec, shared_arg_name_set, + &in_arg_vec, &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + } else { + // Checks to see if this env var has been set to true or false by the user. + // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn + // them and describe further steps. + const int unset_indicator = std::numeric_limits::quiet_NaN(); + if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) { + LOG(INFO) << "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT " + "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) " + "to enable."; + } +#endif // MXNET_USE_TENSORRT + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); +#if MXNET_USE_TENSORRT + } +#endif // MXNET_USE_TENSORRT // copy ndarray ptrs to ret->handles so that front end // can access them @@ -597,6 +626,25 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete out); } +int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, + SymbolHandle *out) { + auto s = new nnvm::Symbol(); + API_BEGIN(); + +#if MXNET_USE_TENSORRT + auto exec = static_cast(handle); + *s = exec->GetOptimizedSymbol(); + *out = s; +#else + LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with " + "MXNET_USE_TENSORRT enabled. Please re-compile MXNet with TensorRT support."; +#endif // MXNET_USE_TENSORRT + + API_END_HANDLE_ERROR(delete s); +} + + + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle) { diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 816599b955c1..fbe544221a35 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -24,10 +24,14 @@ #ifndef MXNET_COMMON_EXEC_UTILS_H_ #define MXNET_COMMON_EXEC_UTILS_H_ +#include +#include +#include #include #include #include #include "../common/utils.h" +#include "../executor/exec_pass.h" namespace mxnet { namespace common { @@ -366,6 +370,257 @@ inline void LogInferStorage(const nnvm::Graph& g) { } } +// prints a helpful message after shape inference errors in executor. 
+inline void HandleInferShapeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_shape << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " + "(0s means unknown dimensions). Please consider providing them as inputs:\n" + << oss.str(); +} + +// prints a helpful message after type inference errors in executor. +inline void HandleInferTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_dtype = inferred_dtypes[eid]; + if (inferred_dtype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_dtype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " + "(-1 means unknown dtype). Please consider providing them as inputs:\n" + << oss.str(); +} + +// prints a helpful message after storage type checking errors in executor. +inline void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const StorageTypeVector& inferred_stypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_stype = inferred_stypes[eid]; + if (inferred_stype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << common::stype_string(inferred_stype) << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments " + "(-1 means unknown stype). Please consider providing them as inputs:\n" + << oss.str(); +} + +/*! + * \brief If the requested ndarray's shape size is less than + * the corresponding shared_data_array's shape size and the + * storage type is shareable, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. + * Shareable storages include both default storage and row_sparse storage + * if enable_row_sparse_sharing is `True`, otherwise default storage only. 
+ */ +inline NDArray ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, + const Context& ctx, + std::unordered_map* shared_buffer, + bool enable_row_sparse_sharing) { + bool stype_shareable = dest_arg_stype == kDefaultStorage; + if (enable_row_sparse_sharing) { + stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; + } + auto it = shared_buffer->find(name); + if (it != shared_buffer->end()) { + // check if size is large enough for sharing + bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size(); + if (size_shareable && stype_shareable) { // memory can be reused + CHECK_EQ(it->second.dtype(), dest_arg_dtype) + << "Requested arg array's dtype does not match that of the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), dest_arg_stype) + << "Requested arg array's stype does not match that of the reusable ndarray"; + return it->second.Reshape(dest_arg_shape); + } else if (stype_shareable) { + LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape + << ", which is larger than already allocated shape " << it->second.shape() + << ". Need to re-allocate. Consider putting default bucket key to be " + << "the bucket taking the largest input for better memory sharing."; + // size is not large enough, creating a larger one for sharing + // the NDArrays in shared_buffer are guaranteed to be of shareable storages + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + return it->second; + } else { + // not shareable storage + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } + } else { + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + if (stype_shareable) { + shared_buffer->emplace(name, ret); + } + return ret; + } // if (it != shared_buffer->end()) +} + +/*! + * \brief Assign context to the graph. + * This is triggered by both simple_bind and bind flows. + */ +inline nnvm::Graph AssignContext(nnvm::Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + size_t num_forward_inputs, + size_t num_forward_outputs) { + const auto& idx = g.indexed_graph(); + const auto& mutable_nodes = idx.mutable_input_nodes(); + // default use default context. + if (ctx_map.size() == 0) { + g.attrs["context"] = std::make_shared( + exec::ContextVector(idx.num_nodes(), default_ctx)); + for (const auto& x : in_arg_ctxes) { + CHECK(x == default_ctx) + << "Input array is in " << x << " while binding with ctx=" << default_ctx + << ". All arguments must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; + } + for (const auto& x : arg_grad_ctxes) { + CHECK(x == default_ctx) + << "Gradient array is in " << x << " while binding with ctx=" + << default_ctx << ". All gradients must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; + } + return g; + } + + // otherwise, use context assignment. 
+ std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id + nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id + nnvm::DeviceAssignMap device_map; // map arg name to device id + + // loop through the user input ctx_map and + // populate maps and lists + for (auto &kv : ctx_map) { + if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one + ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx + ctx_list.push_back(kv.second); // save ctx to the list + } + // assign device id to to the arg name with the corresponding ctx + device_map[kv.first] = ctx2id.at(kv.second); + } + + // loop through all the rest of input nodes not specified + // in the ctx_map and populate maps and lists + size_t arg_top = 0, aux_top = 0; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + Context ctx; + if (mutable_nodes.count(nid)) { // aux node is mutable + CHECK_LT(aux_top, aux_state_ctxes.size()); + ctx = aux_state_ctxes[aux_top]; + ++aux_top; + } else { // regular input node is immutable + CHECK_LT(arg_top, in_arg_ctxes.size()); + ctx = in_arg_ctxes[arg_top]; + ++arg_top; + } + if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id + ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id + ctx_list.push_back(ctx); // save the current ctx in the list + } + device[nid] = ctx2id.at(ctx); // assign device id to the current node + } + + // loop through backward input nodes and populate maps and lists + // the backward input nodes is the gradient of the loss wrt the output + size_t arg_grad_offset = 0; + // keep an offset into the arg_grad_ctxes vector, + // since g.outputs exclude arg_grad whose req == null + CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) + << "insufficient number of grad_reqs"; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + const uint32_t nid = idx.outputs()[i].node_id; + Context ctx = arg_grad_ctxes[arg_grad_offset]; + if (ctx2id.count(ctx) == 0) { + ctx2id[ctx] = static_cast(ctx_list.size()); + ctx_list.push_back(ctx); + } + int devid = ctx2id.at(ctx); + if (device[nid] != -1) { + CHECK_EQ(device[nid], devid) << "device of same output not equal to each other"; + } else { + device[nid] = devid; + } + } + + g.attrs["device"] = std::make_shared(std::move(device)); + g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); + const auto& assigned_device = g.GetAttr("device"); + + exec::ContextVector vcontext; + for (size_t i = 0; i < assigned_device.size(); ++i) { + if (assigned_device[i] == -1) { + vcontext.push_back(default_ctx); + } else { + vcontext.push_back(ctx_list[assigned_device[i]]); + } + } + + // after device planning, we should check again + // if the assigned device of gradient node + // corresponds to storage of grads + auto &new_idx = g.indexed_graph(); + arg_grad_offset = 0; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + const uint32_t nid = new_idx.outputs()[i].node_id; + Context ctx = arg_grad_ctxes[arg_grad_offset]; + CHECK(ctx == vcontext[nid]) + << "Trying to save gradient to " << ctx + << " while its source node \"" << new_idx[nid].source->attrs.name + << "\" computes it on " << vcontext[nid] + << ". 
Check your ctx in NDArray allocation."; + } + + g.attrs["context"] = std::make_shared(std::move(vcontext)); + return g; +} } // namespace common } // namespace mxnet diff --git a/src/common/serialization.h b/src/common/serialization.h new file mode 100644 index 000000000000..8a1bcc6e6ed2 --- /dev/null +++ b/src/common/serialization.h @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file serialization.h + * \brief Serialization of some STL and nnvm data-structures + * \author Clement Fuji Tsang + */ + +#ifndef MXNET_COMMON_SERIALIZATION_H_ +#define MXNET_COMMON_SERIALIZATION_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace mxnet { +namespace common { + +template +inline size_t SerializedSize(const T &obj); + +template +inline size_t SerializedSize(const nnvm::Tuple &obj); + +template +inline size_t SerializedSize(const std::map &obj); + +template<> +inline size_t SerializedSize(const std::string &obj); + +template +inline size_t SerializedSize(const std::tuple &obj); + +template +inline void Serialize(const T &obj, char **buffer); + +template +inline void Serialize(const nnvm::Tuple &obj, char **buffer); + +template +inline void Serialize(const std::map &obj, char **buffer); + +template<> +inline void Serialize(const std::string &obj, char **buffer); + +template +inline void Serialize(const std::tuple &obj, char **buffer); + +template +inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(std::map *obj, const std::string &buffer, size_t *curr_pos); + +template<> +inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(std::tuple *obj, const std::string &buffer, size_t *curr_pos); + + +template +struct is_container { + static const bool value = !std::is_pod::value; +}; + +template +inline size_t SerializedSize(const T &obj) { + return sizeof(T); +} + +template +inline size_t SerializedSize(const nnvm::Tuple &obj) { + if (is_container::value) { + size_t sum_val = 4; + for (const auto& el : obj) { + sum_val += SerializedSize(el); + } + return sum_val; + } else { + return 4 + (obj.ndim() * sizeof(T)); + } +} + +template +inline size_t SerializedSize(const std::map &obj) { + size_t sum_val = 4; + if (is_container::value && is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.first) + SerializedSize(p.second); + } + } else if (is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.first); 
+ } + sum_val += sizeof(V) * obj.size(); + } else if (is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.second); + } + sum_val += sizeof(K) * obj.size(); + } else { + sum_val += (sizeof(K) + sizeof(V)) * obj.size(); + } + return sum_val; +} + +template<> +inline size_t SerializedSize(const std::string &obj) { + return obj.size() + 4; +} + +template +struct serialized_size_tuple { + template + static inline size_t Compute(const std::tuple &obj) { + return SerializedSize(std::get(obj)) + serialized_size_tuple::Compute(obj); + } +}; + +template<> +struct serialized_size_tuple<0> { + template + static inline size_t Compute(const std::tuple &obj) { + return SerializedSize(std::get<0>(obj)); + } +}; + +template +inline size_t SerializedSize(const std::tuple &obj) { + return serialized_size_tuple::Compute(obj); +} + +// Serializer + +template +inline size_t SerializedContainerSize(const T &obj, char **buffer) { + uint32_t size = obj.size(); + std::memcpy(*buffer, &size, 4); + *buffer += 4; + return (size_t) size; +} + +template +inline void Serialize(const T &obj, char **buffer) { + std::memcpy(*buffer, &obj, sizeof(T)); + *buffer += sizeof(T); +} + +template +inline void Serialize(const nnvm::Tuple &obj, char **buffer) { + uint32_t size = obj.ndim(); + std::memcpy(*buffer, &size, 4); + *buffer += 4; + for (auto& el : obj) { + Serialize(el, buffer); + } +} + +template +inline void Serialize(const std::map &obj, char **buffer) { + SerializedContainerSize(obj, buffer); + for (auto& p : obj) { + Serialize(p.first, buffer); + Serialize(p.second, buffer); + } +} + +template<> +inline void Serialize(const std::string &obj, char **buffer) { + auto size = SerializedContainerSize(obj, buffer); + std::memcpy(*buffer, &obj[0], size); + *buffer += size; +} + +template +struct serialize_tuple { + template + static inline void Compute(const std::tuple &obj, char **buffer) { + serialize_tuple::Compute(obj, buffer); + Serialize(std::get(obj), buffer); + } +}; + +template<> +struct serialize_tuple<0> { + template + static inline void Compute(const std::tuple &obj, char **buffer) { + Serialize(std::get<0>(obj), buffer); + } +}; + +template +inline void Serialize(const std::tuple &obj, char **buffer) { + serialize_tuple::Compute(obj, buffer); +} + +// Deserializer + +template +inline size_t DeserializedContainerSize(T *obj, const std::string &buffer, size_t *curr_pos) { + uint32_t size = obj->size(); + std::memcpy(&size, &buffer[*curr_pos], 4); + *curr_pos += 4; + return (size_t) size; +} + +template +inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) { + std::memcpy(obj, &buffer[*curr_pos], sizeof(T)); + *curr_pos += sizeof(T); +} + +template +inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos) { + uint32_t size = obj->ndim(); + std::memcpy(&size, &buffer[*curr_pos], 4); + *curr_pos += 4; + obj->SetDim(size); + for (size_t i = 0; i < size; ++i) { + Deserialize((*obj)[i], buffer, curr_pos); + } +} + +template +inline void Deserialize(std::map *obj, const std::string &buffer, size_t *curr_pos) { + auto size = DeserializedContainerSize(obj, buffer, curr_pos); + K first; + for (size_t i = 0; i < size; ++i) { + Deserialize(&first, buffer, curr_pos); + Deserialize(&(*obj)[first], buffer, curr_pos); + } +} + +template<> +inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos) { + auto size = DeserializedContainerSize(obj, buffer, curr_pos); + obj->resize(size); + std::memcpy(&(obj->front()), 
&buffer[*curr_pos], size); + *curr_pos += size; +} + +template +struct deserialize_tuple { + template + static inline void Compute(std::tuple *obj, + const std::string &buffer, size_t *curr_pos) { + deserialize_tuple::Compute(obj, buffer, curr_pos); + Deserialize(&std::get(*obj), buffer, curr_pos); + } +}; + +template<> +struct deserialize_tuple<0> { + template + static inline void Compute(std::tuple *obj, + const std::string &buffer, size_t *curr_pos) { + Deserialize(&std::get<0>(*obj), buffer, curr_pos); + } +}; + +template +inline void Deserialize(std::tuple *obj, const std::string &buffer, size_t *curr_pos) { + deserialize_tuple::Compute(obj, buffer, curr_pos); +} + + +template +inline void Serialize(const T& obj, std::string* serialized_data) { + serialized_data->resize(SerializedSize(obj)); + char* curr_pos = &(serialized_data->front()); + Serialize(obj, &curr_pos); + CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()), + serialized_data->size()); +} + +template +inline void Deserialize(T* obj, const std::string& serialized_data) { + size_t curr_pos = 0; + Deserialize(obj, serialized_data, &curr_pos); + CHECK_EQ(curr_pos, serialized_data.size()); +} + +} // namespace common +} // namespace mxnet +#endif // MXNET_COMMON_SERIALIZATION_H_ diff --git a/src/common/utils.h b/src/common/utils.h index 96949a047fba..fcc3da82b051 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -675,6 +675,37 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) { return k; } +/*! + * \brief Return an NDArray of all zeros. + */ +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +/*! + * \brief Helper to add a NDArray of zeros to a std::vector. + */ +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 26a249118940..8c483e9b2b8e 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -198,6 +198,18 @@ Graph InferStorageType(Graph&& graph, StorageTypeVector&& storage_type_inputs = StorageTypeVector(), const std::string& storage_type_attr_key = ""); +#if MXNET_USE_TENSORRT +/*! 
+ * \brief Replace subgraphs by TRT (forward only) + */ +Graph ReplaceSubgraph(Graph&& g, + const std::unordered_set& set_subgraph, + std::unordered_map* const params_map); + +std::vector> GetTrtCompatibleSubsets(const Graph& g, + std::unordered_map* const params_map); +#endif + } // namespace exec } // namespace mxnet diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7386de4d12e3..6810800c8b71 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -37,6 +37,8 @@ namespace mxnet { namespace exec { +using namespace mxnet::common; + GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); need_grad_ = false; @@ -56,30 +58,6 @@ GraphExecutor::~GraphExecutor() { } } -inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype) { - // NDArray with default storage - if (stype == kDefaultStorage) { - NDArray ret(shape, ctx, false, dtype); - ret = 0; - return ret; - } - // NDArray with non-default storage. Storage allocation is always delayed. - return NDArray(stype, shape, ctx, true, dtype); -} - -inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec) { - // NDArray with default storage - if (stype == kDefaultStorage) { - vec->emplace_back(shape, ctx, false, dtype); - vec->back() = 0; - } else { - // NDArray with non-default storage. Storage allocation is always delayed. - vec->emplace_back(stype, shape, ctx, true, dtype); - } -} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -308,204 +286,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, return g; } -/*! - * \brief Assign context to the graph. - * This is triggered by both simple_bind and bind flows. - */ -static Graph AssignContext(Graph g, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - size_t num_forward_inputs, - size_t num_forward_outputs) { - const auto& idx = g.indexed_graph(); - const auto& mutable_nodes = idx.mutable_input_nodes(); - // default use default context. - if (ctx_map.size() == 0) { - g.attrs["context"] = std::make_shared( - ContextVector(idx.num_nodes(), default_ctx)); - for (const auto& x : in_arg_ctxes) { - CHECK(x == default_ctx) - << "Input array is in " << x << " while binding with ctx=" << default_ctx - << ". All arguments must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; - } - for (const auto& x : arg_grad_ctxes) { - CHECK(x == default_ctx) - << "Gradient array is in " << x << " while binding with ctx=" - << default_ctx << ". All gradients must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; - } - return g; - } - - // otherwise, use context assignment. 
- std::map ctx2id; // map ctx to device id - std::vector ctx_list; // index is device id - nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id - nnvm::DeviceAssignMap device_map; // map arg name to device id - - // loop through the user input ctx_map and - // populate maps and lists - for (auto &kv : ctx_map) { - if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one - ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx - ctx_list.push_back(kv.second); // save ctx to the list - } - // assign device id to to the arg name with the corresponding ctx - device_map[kv.first] = ctx2id.at(kv.second); - } - - // loop through all the rest of input nodes not specified - // in the ctx_map and populate maps and lists - size_t arg_top = 0, aux_top = 0; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - Context ctx; - if (mutable_nodes.count(nid)) { // aux node is mutable - CHECK_LT(aux_top, aux_state_ctxes.size()); - ctx = aux_state_ctxes[aux_top]; - ++aux_top; - } else { // regular input node is immutable - CHECK_LT(arg_top, in_arg_ctxes.size()); - ctx = in_arg_ctxes[arg_top]; - ++arg_top; - } - if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id - ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id - ctx_list.push_back(ctx); // save the current ctx in the list - } - device[nid] = ctx2id.at(ctx); // assign device id to the current node - } - - // loop through backward input nodes and populate maps and lists - // the backward input nodes is the gradient of the loss wrt the output - size_t arg_grad_offset = 0; - // keep an offset into the arg_grad_ctxes vector, - // since g.outputs exclude arg_grad whose req == null - CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) - << "insufficient number of grad_reqs"; - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; - const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; - if (ctx2id.count(ctx) == 0) { - ctx2id[ctx] = static_cast(ctx_list.size()); - ctx_list.push_back(ctx); - } - int devid = ctx2id.at(ctx); - if (device[nid] != -1) { - CHECK_EQ(device[nid], devid) << "device of same output not equal to each other"; - } else { - device[nid] = devid; - } - } - - g.attrs["device"] = std::make_shared(std::move(device)); - g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); - const auto& assigned_device = g.GetAttr("device"); - - ContextVector vcontext; - for (size_t i = 0; i < assigned_device.size(); ++i) { - if (assigned_device[i] == -1) { - vcontext.push_back(default_ctx); - } else { - vcontext.push_back(ctx_list[assigned_device[i]]); - } - } - - // after device planning, we should check again - // if the assigned device of gradient node - // corresponds to storage of grads - auto &new_idx = g.indexed_graph(); - arg_grad_offset = 0; - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; - const uint32_t nid = new_idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; - CHECK(ctx == vcontext[nid]) - << "Trying to save gradient to " << ctx - << " while its source node \"" << new_idx[nid].source->attrs.name - << "\" computes it on " << vcontext[nid] - << ". 
Check your ctx in NDArray allocation."; - } - - g.attrs["context"] = std::make_shared(std::move(vcontext)); - return g; -} - -static void HandleInferShapeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const TShape& inferred_shape = inferred_shapes[eid]; - if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << inferred_shape << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " - "(0s means unknown dimensions). Please consider providing them as inputs:\n" - << oss.str(); -} - -static void HandleInferTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::DTypeVector& inferred_dtypes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const int inferred_dtype = inferred_dtypes[eid]; - if (inferred_dtype == -1) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << inferred_dtype << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " - "(-1 means unknown dtype). Please consider providing them as inputs:\n" - << oss.str(); -} - -static void HandleInferStorageTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const StorageTypeVector& inferred_stypes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const int inferred_stype = inferred_stypes[eid]; - if (inferred_stype == -1) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << common::stype_string(inferred_stype) << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments " - "(-1 means unknown stype). Please consider providing them as inputs:\n" - << oss.str(); -} - /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -680,57 +460,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } } -/*! - * \brief If the requested ndarray's shape size is less than - * the corresponding shared_data_array's shape size and the - * storage type is shareable, reuse the memory allocation - * in shared_buffer; otherwise, create a zero ndarray. - * Shareable storages include both default storage and row_sparse storage - * if enable_row_sparse_sharing is `True`, otherwise default storage only. 
- */ -static NDArray ReshapeOrCreate(const std::string& name, - const TShape& dest_arg_shape, - const int dest_arg_dtype, - const NDArrayStorageType dest_arg_stype, - const Context& ctx, - std::unordered_map* shared_buffer, - bool enable_row_sparse_sharing) { - bool stype_shareable = dest_arg_stype == kDefaultStorage; - if (enable_row_sparse_sharing) { - stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; - } - auto it = shared_buffer->find(name); - if (it != shared_buffer->end()) { - // check if size is large enough for sharing - bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size(); - if (size_shareable && stype_shareable) { // memory can be reused - CHECK_EQ(it->second.dtype(), dest_arg_dtype) - << "Requested arg array's dtype does not match that of the reusable ndarray"; - CHECK_EQ(it->second.storage_type(), dest_arg_stype) - << "Requested arg array's stype does not match that of the reusable ndarray"; - return it->second.Reshape(dest_arg_shape); - } else if (stype_shareable) { - LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape - << ", which is larger than already allocated shape " << it->second.shape() - << ". Need to re-allocate. Consider putting default bucket key to be " - << "the bucket taking the largest input for better memory sharing."; - // size is not large enough, creating a larger one for sharing - // the NDArrays in shared_buffer are guaranteed to be of shareable storages - it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - return it->second; - } else { - // not shareable storage - return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - } - } else { - auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - if (stype_shareable) { - shared_buffer->emplace(name, ret); - } - return ret; - } // if (it != shared_buffer->end()) -} - /*! 
* \brief Initialize in_args, arg_grads, and aux_states * and their data_entry_ of the executor using diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index bfc415b4526a..7b936c300254 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -163,20 +163,21 @@ class GraphExecutor : public Executor { std::vector* aux_state_vec); // Initialize in_args, arg_grads and aux_states with // shared_buffer and shared_exec - void InitArguments(const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes, - const nnvm::DTypeVector& inferred_dtypes, - const StorageTypeVector& inferred_stypes, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, - const Executor* shared_exec, - std::unordered_map* shared_buffer, - std::vector* in_arg_vec, - std::vector* arg_grad_vec, - std::vector* aux_state_vec); + virtual void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, const Context& default_ctx, @@ -212,7 +213,6 @@ class GraphExecutor : public Executor { void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); - // indicate whether there is a backward graph for gradients. bool need_grad_; // internal graph diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc new file mode 100644 index 000000000000..0b4d91be7009 --- /dev/null +++ b/src/executor/onnx_to_tensorrt.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2018 by Contributors + * \file onnx_to_tensorrt.cc + * \brief TensorRT integration with the MXNet executor + * \author Marek Kolodziej, Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include "./onnx_to_tensorrt.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::cout; +using std::cerr; +using std::endl; + +namespace onnx_to_tensorrt { + +struct InferDeleter { + template + void operator()(T* obj) const { + if ( obj ) { + obj->destroy(); + } + } +}; + +template +inline std::shared_ptr InferObject(T* obj) { + if ( !obj ) { + throw std::runtime_error("Failed to create object"); + } + return std::shared_ptr(obj, InferDeleter()); +} + +std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) { + int onnx_ir_major = ir_version / 1000000; + int onnx_ir_minor = ir_version % 1000000 / 10000; + int onnx_ir_patch = ir_version % 10000; + return (std::to_string(onnx_ir_major) + "." + + std::to_string(onnx_ir_minor) + "." + + std::to_string(onnx_ir_patch)); +} + +void PrintVersion() { + cout << "Parser built against:" << endl; + cout << " ONNX IR version: " << onnx_ir_version_string(onnx::IR_VERSION) << endl; + cout << " TensorRT version: " + << NV_TENSORRT_MAJOR << "." + << NV_TENSORRT_MINOR << "." + << NV_TENSORRT_PATCH << endl; +} + +nvinfer1::ICudaEngine* onnxToTrtCtx( + const std::string& onnx_model, + int32_t max_batch_size, + size_t max_workspace_size, + nvinfer1::ILogger::Severity verbosity, + bool debug_builder) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + TRT_Logger trt_logger(verbosity); + auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger)); + auto trt_network = InferObject(trt_builder->createNetwork()); + auto trt_parser = InferObject(nvonnxparser::createParser( + *trt_network, trt_logger)); + ::ONNX_NAMESPACE::ModelProto parsed_model; + // We check for a valid parse, but the main effect is the side effect + // of populating parsed_model + if (!parsed_model.ParseFromString(onnx_model)) { + throw dmlc::Error("Could not parse ONNX from string"); + } + + if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) { + int nerror = trt_parser->getNbErrors(); + for ( int i=0; i < nerror; ++i ) { + nvonnxparser::IParserError const* error = trt_parser->getError(i); + if ( error->node() != -1 ) { + ::ONNX_NAMESPACE::NodeProto const& node = + parsed_model.graph().node(error->node()); + cerr << "While parsing node number " << error->node() + << " [" << node.op_type(); + if ( !node.output().empty() ) { + cerr << " -> \"" << node.output(0) << "\""; + } + cerr << "]:" << endl; + if ( static_cast(verbosity) >= \ + static_cast(nvinfer1::ILogger::Severity::kINFO) ) { + cerr << "--- Begin node ---" << endl; + cerr << node.DebugString() << endl; + cerr << "--- End node ---" << endl; + } + } + cerr << "ERROR: " + << error->file() << ":" << error->line() + << " In function " << error->func() << ":\n" + << "[" << static_cast(error->code()) << "] " << error->desc() + << endl; + } + throw dmlc::Error("Cannot parse ONNX into TensorRT Engine"); + } + + bool fp16 = trt_builder->platformHasFastFp16(); + + trt_builder->setMaxBatchSize(max_batch_size); + trt_builder->setMaxWorkspaceSize(max_workspace_size); + if ( fp16 && dmlc::GetEnv("MXNET_TENSORRT_USE_FP16_FOR_FP32", false) ) { + LOG(INFO) << "WARNING: TensorRT using fp16 given original MXNet graph in fp32 !!!"; + trt_builder->setHalf2Mode(true); + } + + trt_builder->setDebugSync(debug_builder); + nvinfer1::ICudaEngine* trt_engine = 
trt_builder->buildCudaEngine(*trt_network.get()); + return trt_engine; +} + +} // namespace onnx_to_tensorrt + +#endif // MXNET_USE_TENSORRT diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h new file mode 100644 index 000000000000..259cfce7c332 --- /dev/null +++ b/src/executor/onnx_to_tensorrt.h @@ -0,0 +1,77 @@ +#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ +#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file onnx_to_tensorrt.h + * \brief TensorRT integration with the MXNet executor + * \author Marek Kolodziej, Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include + +#include "../operator/contrib/tensorrt-inl.h" + +namespace onnx_to_tensorrt { + +class TRT_Logger : public nvinfer1::ILogger { + nvinfer1::ILogger::Severity _verbosity; + std::ostream* _ostream; + public: + TRT_Logger(Severity verbosity = Severity::kWARNING, + std::ostream& ostream = std::cout) + : _verbosity(verbosity), _ostream(&ostream) {} + void log(Severity severity, const char* msg) override { + if ( severity <= _verbosity ) { + time_t rawtime = std::time(0); + char buf[256]; + strftime(&buf[0], 256, + "%Y-%m-%d %H:%M:%S", + std::gmtime(&rawtime)); + const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : + severity == Severity::kERROR ? " ERROR" : + severity == Severity::kWARNING ? "WARNING" : + severity == Severity::kINFO ? " INFO" : + "UNKNOWN"); + (*_ostream) << "[" << buf << " " << sevstr << "] " + << msg + << std::endl; + } + } +}; + +nvinfer1::ICudaEngine* onnxToTrtCtx( + const std::string& onnx_model, + int32_t max_batch_size = 32, + size_t max_workspace_size = 1L << 30, + nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING, + bool debug_builder = false); +} // namespace onnx_to_tensorrt + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc new file mode 100644 index 000000000000..b5fc8d15f7ac --- /dev/null +++ b/src/executor/tensorrt_pass.cc @@ -0,0 +1,596 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file tensorrt_pass.cc + * \brief Replace TRT compatible subgraphs by TRT engines + * \author Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include +#include + +#include "../operator/contrib/nnvm_to_onnx-inl.h" +#include "./exec_pass.h" +#include "./onnx_to_tensorrt.h" + +namespace mxnet { +namespace exec { + +using NodePtr = nnvm::NodePtr; + +/*! + * \brief Custom graph class, which will contain bi-directional nodes + * we need to compute DFS and reverse DFS for graph partitioning + */ +class BidirectionalGraph { + public: + struct Node { + nnvm::Node* nnvmptr; + std::vector inputs; + std::vector outputs; + }; + std::vector nodes; + std::unordered_map nnvm2nid; + std::vector outputs; + static const std::unordered_set unconditionalTRTop; + + explicit BidirectionalGraph(const Graph &g) { + auto& idx = g.indexed_graph(); + auto num_nodes = idx.num_nodes(); + nodes.reserve(num_nodes); + nnvm2nid.reserve(num_nodes); + outputs.reserve(idx.outputs().size()); + DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) { + BidirectionalGraph::Node new_node; + new_node.nnvmptr = n.get(); + nnvm2nid[n.get()] = static_cast(nodes.size()); + nodes.emplace_back(std::move(new_node)); + }); + for (const auto& it : nnvm2nid) { + nnvm::Node* nnvmnode = it.first; + uint32_t nid = it.second; + for (auto& n : nnvmnode->inputs) { + uint32_t input_nid = nnvm2nid[n.node.get()]; + nodes[input_nid].outputs.emplace_back(&nodes[nid]); + nodes[nid].inputs.emplace_back(&nodes[input_nid]); + } + } + for (auto& e : g.outputs) { + uint32_t nid = nnvm2nid[e.node.get()]; + outputs.emplace_back(&nodes[nid]); + } + } + + template + void DFS(const std::vector& heads, bool reverse, FVisit fvisit) { + std::unordered_set visited; + std::vector vec(heads.begin(), heads.end()); + visited.reserve(heads.size()); + while (!vec.empty()) { + Node* vertex = vec.back(); + vec.pop_back(); + if (visited.count(vertex) == 0) { + visited.insert(vertex); + fvisit(vertex); + std::vector nexts = reverse ? 
vertex->inputs : vertex->outputs; + for (Node* node : nexts) { + if (visited.count(node) == 0) { + vec.emplace_back(node); + } + } + } + } + } + + using t_pairset = std::pair, std::unordered_set>; + using t_pairvec = std::pair, std::vector>; + using t_uncomp_map = std::unordered_map>; + + std::unordered_set naive_grow_subgraph(Node* head, + std::unordered_set* set_unused, + t_uncomp_map* uncomp_map) { + std::unordered_set subgraph; + std::unordered_set uncomp_set; + std::deque stack; + stack.emplace_back(head); + while (!stack.empty()) { + Node* vertex = stack.back(); + stack.pop_back(); + if (set_unused->count(vertex) && !uncomp_set.count(vertex)) { + set_unused->erase(vertex); + subgraph.insert(vertex); + uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end()); + for (Node* input : vertex->inputs) { + if (set_unused->count(input) && !uncomp_set.count(input)) { + stack.emplace_back(input); + } + } + for (Node* output : vertex->outputs) { + if (set_unused->count(output) && !uncomp_set.count(output)) { + stack.emplace_back(output); + } + } + } + } + return subgraph; + } + + std::vector> get_subsets( + std::unordered_map* const params_map) { + std::vector> subgraphs; + std::unordered_set set_nonTRTnodes; + std::unordered_set set_allnodes(nodes.size()); + std::vector separation_sets; + for (Node& node : nodes) { + if (!IsTRTCompatible(node.nnvmptr)) { + set_nonTRTnodes.insert(&node); + std::unordered_set in_graph; + std::unordered_set out_graph; + std::vector dummy_head; + dummy_head.emplace_back(&node); + DFS(dummy_head, false, [&out_graph](Node* node) { + out_graph.insert(node); + }); + DFS(dummy_head, true, [&in_graph](Node* node) { + in_graph.insert(node); + }); + separation_sets.emplace_back(std::make_pair(in_graph, out_graph)); + } + set_allnodes.emplace(&node); + } + t_uncomp_map uncomp_map; + std::unordered_set set_TRTnodes; + set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end()); + for (Node* n : set_nonTRTnodes) { + set_TRTnodes.erase(n); + } + for (Node* n : set_TRTnodes) { + for (t_pairset p : separation_sets) { + if (p.first.count(n)) { + uncomp_map[n].insert(p.second.begin(), p.second.end()); + } else if (p.second.count(n)) { + uncomp_map[n].insert(p.first.begin(), p.first.end()); + } + } + for (Node* nonTRTn : set_nonTRTnodes) { + uncomp_map[n].erase(nonTRTn); + } + } + std::unordered_set set_unused; + set_unused.reserve(set_TRTnodes.size()); + + for (auto& n : set_TRTnodes) { + if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) { + set_unused.insert(n); + } + } + std::unordered_set visited; + std::deque stack(outputs.begin(), outputs.end()); + while (!stack.empty()) { + Node* vertex = stack.front(); + stack.pop_front(); + if (!visited.count(vertex)) { + visited.insert(vertex); + if (set_unused.count(vertex)) { + subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map)); + } + for (Node* input : vertex->inputs) { + stack.emplace_back(input); + } + } + } + + return subgraphs; + } + + + private: + friend class Graph; + + bool IsTRTCompatible(nnvm::Node* nodeptr) { + if (nodeptr->op() == nullptr) { + return true; + } + + const std::string op_name = nodeptr->op()->name; + if (op_name == "Pooling") { + return (nodeptr->attrs.dict.at("pool_type") == "avg" || + nodeptr->attrs.dict.at("pool_type") == "max"); + } + + if (unconditionalTRTop.count(op_name)) { + return true; + } + + if (op_name == "Activation") { + return nodeptr->attrs.dict.at("act_type") == "relu" || + nodeptr->attrs.dict.at("act_type") == 
"tanh" || + nodeptr->attrs.dict.at("act_type") == "sigmoid"; + } + + return false; + } +}; // class BidirectionalGraph + +/*! + * \brief function which transform std::vector back to Attrs (dmlc::any) + */ +const std::unordered_set BidirectionalGraph::unconditionalTRTop = { + "Convolution", + "BatchNorm", + "elemwise_add", + "elemwise_sub", + "elemwise_mul", + "rsqrt", + "pad", + "Pad", + "mean", + "FullyConnected", + "Flatten", + "SoftmaxOutput", +}; + + +using NodeEntrySet = std::unordered_set; + +/*! + * \brief get the output nodes of the subgraph in the main graph + * \return a vector of the output nodes +*/ +std::vector GetSubgraphNodeEntries(Graph g, + std::unordered_set set_subgraph) { + std::vector outputs; + NodeEntrySet _outputs; + for (auto& e : g.outputs) { + if (set_subgraph.count(e.node.get())) { + _outputs.insert(e); + } + } + DFSVisit(g.outputs, [&set_subgraph, &_outputs](const nnvm::NodePtr &node){ + if (!set_subgraph.count(node.get())) { + for (auto& e : node->inputs) { + if (set_subgraph.count(e.node.get())) { + _outputs.insert(e); + } + } + } + }); + outputs.insert(outputs.begin(), _outputs.begin(), _outputs.end()); + return outputs; +} + + +/*! + * \brief get the nodes outside of the subgraph for which outputs are used in the subgraph + * \return a vector the nodes +*/ +std::vector GetSubgraphInterfaceNodes(Graph g, + std::unordered_set set_subgraph) { + std::vector inputs; + NodeEntrySet _inputs; + DFSVisit(g.outputs, [&set_subgraph, &_inputs](const nnvm::NodePtr &node){ + if (set_subgraph.count(node.get())) { + for (auto& e : node->inputs) { + if (!set_subgraph.count(e.node.get())) { + _inputs.insert(e); + } + } + } + }); + inputs.insert(inputs.begin(), _inputs.begin(), _inputs.end()); + return inputs; +} + +std::unordered_map GetGraphInputsMap(const Graph& g) { + std::unordered_map outputs; + auto& idx = g.indexed_graph(); + outputs.reserve(idx.num_nodes()); + std::vector input_nodes = idx.input_nodes(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + outputs[input_nodes[i]] = static_cast(i); + } + return outputs; +} + +/*! + * \brief Dummy function which creates a fake TensorRT Node + */ +nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g, + std::unordered_map* const params_map) { + auto p = nnvm::Node::Create(); + p->attrs.op = nnvm::Op::Get("_trt_op"); + op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map); + p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map; + p->attrs.dict["serialized_input_map"] = trt_param.serialized_input_map; + p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph; + if (p->op()->attr_parser != nullptr) { + p->op()->attr_parser(&(p->attrs)); + } + return p; +} + +/*! 
+ * \brief Update attributes of the graph (such as some inputs properties) + */ +Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g, + const std::unordered_map& old2new, + const nnvm::NodeEntryMap& main_input_entry_to_sub) { + const auto& idx = g.indexed_graph(); + const auto& sub_idx = subgraph.indexed_graph(); + + const auto& shape = g.GetAttr("shape"); + const auto& dtype = g.GetAttr("dtype"); + const auto& storage_type = g.GetAttr("storage_type"); + const auto& shape_inputs = g.GetAttr("shape_inputs"); + const auto& dtype_inputs = g.GetAttr("dtype_inputs"); + const auto& storage_type_inputs = g.GetAttr("storage_type_inputs"); + + nnvm::ShapeVector sub_shape(sub_idx.num_node_entries()); + nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries()); + StorageTypeVector sub_storage_type(sub_idx.num_node_entries()); + nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size()); + nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size()); + StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size()); + + const std::unordered_map inputsindex2pos = GetGraphInputsMap(g); + const std::unordered_map sub_inputsindex2pos = GetGraphInputsMap(subgraph); + // map attributes from graph to subgraph + for (auto& p : old2new) { + const uint32_t nid = idx.node_id(p.first); + const uint32_t sub_nid = sub_idx.node_id(p.second.get()); + const nnvm::Op* op = sub_idx[sub_nid].source->op(); + if (op == nullptr) { // if it's an input node, there is only one output node entry + const uint32_t sub_i = sub_idx.entry_id(sub_nid, 0); + const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid); + const uint32_t i = idx.entry_id(nid, 0); + + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + sub_shape_inputs[sub_input_i] = shape_inputs[inputsindex2pos.at(nid)]; + sub_dtype_inputs[sub_input_i] = dtype_inputs[inputsindex2pos.at(nid)]; + sub_storage_type_inputs[sub_input_i] = storage_type_inputs[inputsindex2pos.at(nid)]; + + } else { + for (size_t oi = 0; oi < op->num_outputs; ++oi) { + const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi); + const uint32_t i = idx.entry_id(nid, oi); + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + } + } + } + // old2new doesn't contain placeholder / interfaces + for (auto& p : main_input_entry_to_sub) { + nnvm::NodeEntry main_entry = p.first; + nnvm::NodeEntry sub_entry = p.second; + const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get()); + const uint32_t sub_i = sub_idx.entry_id(sub_entry); + const uint32_t i = idx.entry_id(main_entry); + const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid); + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + sub_shape_inputs[sub_input_i] = sub_shape[sub_i]; + sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i]; + sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i]; + } + subgraph.attrs["shape"] = + std::make_shared(std::move(sub_shape)); + subgraph.attrs["dtype"] = + std::make_shared(std::move(sub_dtype)); + subgraph.attrs["storage_type"] = + std::make_shared(std::move(sub_storage_type)); + subgraph.attrs["shape_inputs"] = + std::make_shared(std::move(sub_shape_inputs)); + subgraph.attrs["dtype_inputs"] = + std::make_shared(std::move(sub_dtype_inputs)); + subgraph.attrs["storage_type_inputs"] = + std::make_shared(std::move(sub_storage_type_inputs)); + + return subgraph; +} + +/*! 
+ * \brief Generate a name for a new TRT node, avoid collision if some TRT_nodes are already defined + */ +const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) { + const std::string name_prefix("TRT_node"); + std::unordered_set name_set; + DFSVisit(g.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) { + if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) { + name_set.insert(node->attrs.name); + } + }); + // names inside the subgraph will be available since those nodes will be removed + DFSVisit(subgraph.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) { + if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) { + name_set.erase(node->attrs.name); + } + }); + uint32_t name_suffix = 0; + std::string full_name = name_prefix + std::to_string(name_suffix); + while (name_set.count(full_name)) { + full_name = name_prefix + std::to_string(++name_suffix); + } + return full_name; +} + +/*! + * \brief helper function to display what nodes are in a specific subset + */ +void dispNodesSet(Graph g, std::unordered_set s) { + DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){ + if (s.count(n.get())) { + std::cout << " Y " << n->attrs.name << std::endl; + } else { + std::cout << " N " << n->attrs.name << std::endl; + } + }); +} + +/*! + * \brief Replace a set of nodes by a TensorRT node + */ +Graph ReplaceSubgraph(Graph&& g, + const std::unordered_set& set_subgraph, + std::unordered_map* const params_map) { + // Create MXNet subgraph + Graph subgraph; + + const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph); + subgraph.outputs = sub_outputs_in_main; + // old2new will link raw pointer of the nodes in the graph to + // the corresponding shared_ptr of the nodes in the generated subgraph + std::unordered_map old2new; + std::deque stack; + std::unordered_set visited; + int32_t reservation = set_subgraph.size(); + old2new.reserve(reservation); + visited.reserve(reservation); + + // Creating the shared_ptr using the same raw pointer doesn't really matter + for (auto& n : set_subgraph) { + old2new[n] = std::make_shared(*n); + } + + // To generate a subgraph, an input has to be replaced by a data node (no op) + // and it has to be agnostic to the node from which it's an output + // (for example, even if two inputs are two different outputs from the same node) + nnvm::NodeEntryMap main_input_entry_to_sub; + for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) { + auto node = nnvm::Node::Create(); + node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index); + auto new_e = nnvm::NodeEntry{node, 0, 0}; + main_input_entry_to_sub[e] = new_e; + } + + for (nnvm::NodeEntry& e : subgraph.outputs) { + e.node = old2new[e.node.get()]; + stack.emplace_back(e.node.get()); + } + // link all nodes in the subgraph to nodes in the subgraph instead of the main graph + while (!stack.empty()) { + auto vertex = stack.front(); + stack.pop_front(); + if (!visited.count(vertex)) { + visited.insert(vertex); + for (auto& e : vertex->inputs) { + auto it = main_input_entry_to_sub.find(e); + if (it != main_input_entry_to_sub.end()) { + e = it->second; + } else { + e.node = old2new[e.node.get()]; + } + stack.emplace_back(e.node.get()); + } + } + } + // Remove the control dependencies of the subgraph to nodes that are not in the subgraph + DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) { + std::remove_if(node->control_deps.begin(), + node->control_deps.end(), + [&set_subgraph](nnvm::NodePtr n_ptr) { + return
!set_subgraph.count(n_ptr.get()); + }); + for (nnvm::NodePtr& n_ptr : node->control_deps) { + n_ptr = old2new[n_ptr.get()]; + } + }); + + subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub); + auto& sub_idx = subgraph.indexed_graph(); + + auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map); + trtnodeptr->attrs.name = GetNewTrtName(g, subgraph); + + // Insert new trt node and unplug replaced nodes + std::unordered_map sub_input_entryid_to_main; + for (auto& p : main_input_entry_to_sub) { + sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first; + } + + // Plug the nodes from the main graph as inputs of the trt node + trtnodeptr->inputs.resize(main_input_entry_to_sub.size()); + { + uint32_t counter = 0; + for (uint32_t i : sub_idx.input_nodes()) { + auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0)); + if (it != sub_input_entryid_to_main.end()) { + trtnodeptr->inputs[counter++] = it->second; + } + } + } + nnvm::NodeEntryMap sub_outputs_in_main_to_pos; + for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) { + sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i; + } + // Plug the trt node as inputs to the main graph nodes + DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) { + for (auto& e : n->inputs) { + auto it = sub_outputs_in_main_to_pos.find(e); + if (it != sub_outputs_in_main_to_pos.end()) { + e.index = it->second; + e.node = trtnodeptr; + } + } + }); + + for (auto& output : g.outputs) { + auto it = sub_outputs_in_main_to_pos.find(output); + if (it != sub_outputs_in_main_to_pos.end()) { + output.index = it->second; + output.node = trtnodeptr; + } + } + + Graph new_graph; + new_graph.outputs = g.outputs; + return new_graph; +} + +std::vector> GetTrtCompatibleSubsets(const Graph& g, + std::unordered_map* const params_map) { + BidirectionalGraph biG = BidirectionalGraph(g); + std::vector> subsets = biG.get_subsets(params_map); + std::vector> nnvm_subsets(subsets.size(), + std::unordered_set()); + for (size_t i = 0; i < subsets.size(); ++i) { + nnvm_subsets[i].reserve(subsets[i].size()); + for (auto& n : subsets[i]) { + nnvm_subsets[i].insert(n->nnvmptr); + } + } + return nnvm_subsets; +} + +} // namespace exec +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc new file mode 100644 index 000000000000..65dbb29792e0 --- /dev/null +++ b/src/executor/trt_graph_executor.cc @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#if MXNET_USE_TENSORRT + +#include "trt_graph_executor.h" + +#include +#include +#include "./onnx_to_tensorrt.h" +#include "../operator/contrib/tensorrt-inl.h" +#include "../common/utils.h" +#include "../common/exec_utils.h" + + +namespace mxnet { +namespace exec { + +using namespace mxnet::common; + + /*! + * \brief TrtGraphExecutor initializer for simple bind flow in + * which only certain input shapes and dtypes are provided by users. + * The initializer uses these shapes and dtypes to perform + * shape and dtype inferences, and then create NDArrays + * to populate data entries of the graph. The created NDArrays + * for in_args, arg_grads and aux_states are passed to the + * front end to attach the created executor. + * In front end, if the simple_bind flow is trigger by + * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup + * and shared executor will be taken into account in creating + * NDArrays for in_args, arg_grads, and aux_states for reusing + * already allocated memory. + * + * This version of an executor exports the computation graph to TensorRT make use of fused + * kernels and other runtime enhancements. TRT will compile the sub-graphs to executable fused + * operators without intervention from the user. Operators in the original graph that are not + * supported by TRT will continue to be executed normally by MXNet. + * + */ +void TrtGraphExecutor::Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + symbol = symbol.Copy(); + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types); + + if (need_grad_) { + LOG(FATAL) << "You may be attempting to use TensorRT for training. TensorRT is an inference " + "only library. To re-enable legacy MXNet graph execution, which will support " + "training, set the MXNET_USE_TENSORRT environment variable to 0, or call " + "mx.contrib.tensorrt.set_use_tensorrt(False)"; + } + + if (shared_buffer == nullptr || shared_buffer->empty()) { + LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty. " + << "Please provide weights and other parameters, such as " + << "BatchNorm moments, via the shared_buffer, during simple bind call."; + } + + // The following code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize arg_shapes and arg_dtypes for shape and type inferences. + // It contains all in_args and aux_states' shapes and types in a certain order. 
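+ // The order is that of idx.input_nodes(): slot i holds the shape/dtype/stype of the i-th graph input node.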
+ const nnvm::IndexedGraph& idx = g.indexed_graph(); + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { + arg_stypes[i] = it3->second; + } + } + g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } + + g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + + auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); + for (auto trt_group : trt_groups) { + if (trt_group.size() > 1) { + g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer); + g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map, + arg_stype_map, shared_buffer); + } + } + + + InitArguments(g.indexed_graph(), g.GetAttr("shape"), + g.GetAttr("dtype"), + g.GetAttr("storage_type"), + *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, + *grad_req_types, shared_arg_names, shared_exec, + shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); + + // The above code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor using + * shared_buffer from DataParallelExecutorGroup + * and shared_exec if available. 
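+ * Values are reused from shared_exec where possible; otherwise they are copied out of shared_buffer or zero-initialized.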
+ */ +void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + // initialize in_args, arg_grads, and aux_states and populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + const auto& mutable_nodes = idx.mutable_input_nodes(); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec) { + const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); + CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage) + << "Non-default storage type detected when creating auxilliary NDArray. The allocated " + << "memory of shared_exec.aux_array cannot be resued for argument: " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_shape, aux_nd.shape()) + << "Inferred shape does not match shared_exec.aux_array's shape." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument: " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, aux_nd.dtype()) + << "Inferred dtype does not match shared_exec.aux_array's dtype." 
+ " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument: " + << arg_name << " for the current executor"; + aux_state_vec->emplace_back(aux_nd); + } else { + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top]))); + } else { + aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + aux_state_ctxes[aux_top], inferred_dtype))); + } + } // if (has_shared_exec) + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; + } else { // in_args and grad for in_args + if (shared_arg_names.count(arg_name)) { // model parameter + // model parameter + if (nullptr != shared_exec) { + const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); + auto arg_nd_stype = in_arg_nd.storage_type(); + // for model parameter, both default storage and row_sparse storage can be shared + bool shareable_arg_stype = inferred_stype == kDefaultStorage || + inferred_stype == kRowSparseStorage; + // try to reuse memory from shared_exec + CHECK(shareable_arg_stype) << "Inferred storage type " + << common::stype_string(inferred_stype) + << " does not support memory sharing with shared_exec.arg_array"; + CHECK_EQ(inferred_stype, arg_nd_stype) + << "Inferred stype does not match shared_exec.arg_array's stype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_shape, in_arg_nd.shape()) + << "Inferred shape does not match shared_exec.arg_array's shape" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, in_arg_nd.dtype()) + << "Inferred dtype does not match shared_exec.arg_array's dtype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + in_arg_vec->emplace_back(in_arg_nd); + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec + arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); + } else { + // no need to reuse memory from shared_exec for gradient of non-default storage + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } + } else { // !shared_arg_names.count(arg_name) + // model parameter, row_sparse ndarray sharing enabled + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); + } else { + 
in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + in_arg_ctxes[arg_top], inferred_dtype))); + } + // gradient for model parameter, row_sparse ndarray sharing disabled + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + bool enable_row_sparse_sharing = false; + arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer, + enable_row_sparse_sharing)); + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } // if (kNullOp == grad_req_types[arg_top]) + } // if (shared_arg_names.count(arg_name)) + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + if (!arg_grad_vec->back().is_none()) { + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + data_entry_[eid] = in_arg_vec->back(); + ++arg_top; + } + } +} + + + /*! + * \brief This function is triggered after each tensorrt subgraph replacement pass. + * Reset arguments of GraphExecutor::Init(...) as some variables (weights and biases) + * are absorbed into the TRT engine it also it reruns attributes inferences accordingly + * to the new topology. + */ +Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::vector *grad_req_types, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::unordered_map *params_map) { + std::unordered_set to_remove_params; + for (auto& el : *params_map) { + to_remove_params.insert(el.first); + } + + DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) { + to_remove_params.erase(n->attrs.name); + }); + + for (auto& el : to_remove_params) { + params_map->erase(el); + arg_shape_map->erase(el); + arg_dtype_map->erase(el); + arg_stype_map->erase(el); + } + const auto &idx = g.indexed_graph(); + num_forward_inputs_ = idx.input_nodes().size(); + in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + aux_state_ctxes->resize(idx.mutable_input_nodes().size()); + + // create "device" and "context" attrs for the graph + g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types, num_forward_inputs_, + num_forward_outputs_); + + // get number of nodes used in forward pass + num_forward_nodes_ = 0; + for (size_t i = 0; i < num_forward_outputs_; ++i) { + num_forward_nodes_ = std::max( + num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); + } + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string &name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { + arg_dtypes[i] = 
it2->second; + } + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { + arg_stypes[i] = it3->second; + } + } + g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } + + g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + + return g; +} + + +/*! + * \brief Return the "optimized" symbol contained in the graph. + * For optimization pass such as TensorRT pass + */ +nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() { + Symbol ret; + ret.outputs = std::vector(graph_.outputs.begin(), + graph_.outputs.begin() + num_forward_outputs_); + ret = ret.Copy(); + static const Op* trt_op = Op::Get("_trt_op"); + DFSVisit(ret.outputs, [](const nnvm::NodePtr n) { + if (n->op() == trt_op) { + n->attrs.dict.clear(); + } + }); + return ret; +} + +Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol, + const Context &default_ctx, + const std::map &group2ctx, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set ¶m_names, + std::vector *in_args, + std::vector *arg_grads, + std::vector *aux_states, + std::unordered_map *shared_buffer, + Executor *shared_exec) { + auto exec = new exec::TrtGraphExecutor(); + exec->Init(symbol, default_ctx, group2ctx, + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_types, param_names, + in_args, arg_grads, aux_states, + shared_buffer, shared_exec); + return exec; +} + +} // namespace exec + +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h new file mode 100644 index 000000000000..96ac4426270a --- /dev/null +++ b/src/executor/trt_graph_executor.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ +#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ + +#if MXNET_USE_TENSORRT + +#include +#include +#include + +#include "./graph_executor.h" + +namespace mxnet { + +namespace exec { + +class TrtGraphExecutor : public GraphExecutor { + public: + static Executor* TensorRTBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + std::vector *in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* + shared_data_arrays = nullptr, + Executor* shared_exec = nullptr); + + virtual void Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer = nullptr, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); + + // Returns symbol representing the TRT optimized graph for comparison purposes. + nnvm::Symbol GetOptimizedSymbol(); + + protected: + Graph ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::vector *grad_req_types, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::unordered_map *params_map); + + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) override; +}; + +} // namespace exec + +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h new file mode 100644 index 000000000000..58f88b051433 --- /dev/null +++ b/src/operator/contrib/nnvm_to_onnx-inl.h @@ -0,0 +1,156 @@ +#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ +#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file tensorrt-inl.h + * \brief TensorRT Operator + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./tensorrt-inl.h" +#include "../operator_common.h" +#include "../../common/utils.h" +#include "../../common/serialization.h" + +namespace mxnet { +namespace op { +namespace nnvm_to_onnx { + +using namespace nnvm; +using namespace ::onnx; +using int64 = ::google::protobuf::int64; + +std::unordered_map GetPlaceholderShapes(const ShapeVector& shape_inputs, + const nnvm::IndexedGraph& ig); + +std::unordered_map GetOutputLookup(const nnvm::IndexedGraph& ig); + +void ConvertPlaceholder( + const std::string& node_name, + const std::unordered_map& placeholder_shapes, + GraphProto* const graph_proto); + +void ConvertConstant(GraphProto* const graph_proto, + const std::string& node_name, + std::unordered_map* const shared_buffer); + +void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map, + GraphProto* const graph_proto, + const std::unordered_map::iterator& out_iter, + const std::string& node_name, + const nnvm::Graph& g, + const StorageTypeVector& storage_types, + const DTypeVector& dtypes); + +typedef void (*ConverterFunction)(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + + +// Forward declarations +void ConvertConvolution( + NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + + +void ConvertPooling(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertActivation(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertFullyConnected(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertSoftmaxOutput(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertFlatten(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertBatchNorm(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertElementwiseAdd(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +TRTParam ConvertNnvmGraphToOnnx( + const nnvm::Graph &g, + std::unordered_map *const shared_buffer); + +static const std::unordered_map converter_map = { + {"Convolution", ConvertConvolution}, + {"Pooling", ConvertPooling}, + {"Activation", ConvertActivation}, + {"FullyConnected", ConvertFullyConnected}, + {"SoftmaxOutput", ConvertSoftmaxOutput}, + {"Flatten", ConvertFlatten}, + {"BatchNorm", ConvertBatchNorm}, + {"elemwise_add", ConvertElementwiseAdd}}; 
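+ // To convert an additional operator: declare its converter above, register it in this map,
+ // and make sure the op is accepted by IsTRTCompatible() in src/executor/tensorrt_pass.cc.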
+ +} // namespace nnvm_to_onnx +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc new file mode 100644 index 000000000000..902466614c7c --- /dev/null +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file trt.cc + * \brief TensorRT operation registration + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include "./nnvm_to_onnx-inl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../common/serialization.h" +#include "../../common/utils.h" +#include "../../ndarray/ndarray_function.h" +#include "../../operator/nn/activation-inl.h" +#include "../../operator/nn/batch_norm-inl.h" +#include "../../operator/nn/convolution-inl.h" +#include "../../operator/nn/fully_connected-inl.h" +#include "../../operator/nn/pooling-inl.h" +#include "../../operator/softmax_output-inl.h" +#include "./tensorrt-inl.h" + +#if MXNET_USE_TENSORRT_ONNX_CHECKER +#include +#endif // MXNET_USE_TENSORRT_ONNX_CHECKER + +namespace mxnet { +namespace op { +namespace nnvm_to_onnx { + +op::TRTParam ConvertNnvmGraphToOnnx( + const nnvm::Graph& g, + std::unordered_map* const shared_buffer) { + op::TRTParam trt_param; + op::tensorrt::NameToIdx_t trt_input_map; + op::tensorrt::InferenceMap_t trt_output_map; + + const nnvm::IndexedGraph& ig = g.indexed_graph(); + const auto& storage_types = g.GetAttr("storage_type"); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shape_inputs = g.GetAttr("shape_inputs"); + + for (auto& e : storage_types) { + if (e != mshadow::kFloat32) { + LOG(FATAL) << "ONNX converter does not support types other than float32 " + "right now."; + } + } + + ModelProto model_proto; + // Need to determine IR versions and features to support + model_proto.set_ir_version(static_cast(2)); + GraphProto* graph_proto = model_proto.mutable_graph(); + + std::unordered_map placeholder_shapes = + GetPlaceholderShapes(shape_inputs, ig); + std::unordered_map output_lookup = GetOutputLookup(ig); + uint32_t current_input = 0; + + // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc. + for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) { + const IndexedGraph::Node& node = ig[node_idx]; + const nnvm::Node* source = node.source; + const NodeAttrs& attrs = source->attrs; + const Op* op = source->op(); + + std::string node_name = attrs.name; + // Here, "variable" actually means anything that's not an op i.e. 
a constant (weights) or a + // placeholder + if (source->is_variable()) { + // Is this a placeholder? + if (shared_buffer->count(node_name) == 0) { + // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky. + // Need to figure out how to properly fix it. + if (node_name.find("label") != std::string::npos) { + current_input++; + continue; + } + trt_input_map.emplace(node_name, current_input++); + ConvertPlaceholder(node_name, placeholder_shapes, graph_proto); + } else { + // If it's not a placeholder, then by exclusion it's a constant. + ConvertConstant(graph_proto, node_name, shared_buffer); + } // is_placeholder + } else { + // It's an op, rather than a "variable" (constant or placeholder) + NodeProto* node_proto = graph_proto->add_node(); + node_proto->set_name(node_name); + if (converter_map.count(op->name) == 0) { + LOG(FATAL) << "Conversion for node of type " << op->name << " (node " + << node_name << ") " + << " is not supported yet."; + } + // Find function ptr to a converter based on the op name, and invoke the converter. This + // looks unsafe because find may not succeed, but it does because we're in the operator + // logic after testing that this node name does not represent a variable. + converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs); + // Add all inputs to the current node (i.e. add graph edges) + for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) { + std::string in_node_name = ig[entry.node_id].source->attrs.name; + // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less + // hacky way to do it than name matching. + if (in_node_name.find("label") != std::string::npos) { + continue; + } + node_proto->add_input(in_node_name); + } + // The node's output will have the same name as the node name. 
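+ // Downstream consumers will reference this output by that same name when their inputs are added.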
+ node_proto->add_output(node_name); + // See if the current node is an output node + auto out_iter = output_lookup.find(node_name); + // We found an output + if (out_iter != output_lookup.end()) { + ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g, + storage_types, dtypes); + } // output found + } // conversion function exists + } // loop over i from 0 to num_nodes + + model_proto.SerializeToString(&trt_param.serialized_onnx_graph); + common::Serialize(trt_input_map, + &trt_param.serialized_input_map); + common::Serialize(trt_output_map, + &trt_param.serialized_output_map); + +#if MXNET_USE_TENSORRT_ONNX_CHECKER + onnx::checker::check_model(model_proto); +#endif // MXNET_USE_TENSORRT_ONNX_CHECKER + + return trt_param; +} + +void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& conv_param = nnvm::get(attrs.parsed); + + node_proto->set_op_type("Conv"); + + const TShape kernel = conv_param.kernel; + const TShape stride = conv_param.stride; + const TShape dilate = conv_param.dilate; + const TShape pad = conv_param.pad; + const uint32_t num_group = conv_param.num_group; + // const bool no_bias = conv_param.no_bias; + const dmlc::optional layout = conv_param.layout; + + // kernel shape + AttributeProto* const kernel_shape = node_proto->add_attribute(); + kernel_shape->set_name("kernel_shape"); + kernel_shape->set_type(AttributeProto::INTS); + + for (const dim_t kval : kernel) { + kernel_shape->add_ints(static_cast(kval)); + } + + // pads + AttributeProto* const pads = node_proto->add_attribute(); + pads->set_name("pads"); + pads->set_type(AttributeProto::INTS); + + for (const dim_t kval : pad) { + pads->add_ints(static_cast(kval)); + pads->add_ints(static_cast(kval)); + } + + // dilations + AttributeProto* const dilations = node_proto->add_attribute(); + dilations->set_name("dilations"); + dilations->set_type(AttributeProto::INTS); + for (const dim_t kval : dilate) { + dilations->add_ints(static_cast(kval)); + } + + // strides + AttributeProto* const strides = node_proto->add_attribute(); + strides->set_name("strides"); + strides->set_type(AttributeProto::INTS); + for (const dim_t kval : stride) { + strides->add_ints(static_cast(kval)); + } + + // group + AttributeProto* const group = node_proto->add_attribute(); + group->set_name("group"); + group->set_type(AttributeProto::INT); + group->set_i(static_cast(num_group)); +} // end ConvertConvolution + +void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& pooling_param = nnvm::get(attrs.parsed); + + const TShape kernel = pooling_param.kernel; + const TShape stride = pooling_param.stride; + const TShape pad = pooling_param.pad; + const int pool_type = pooling_param.pool_type; + const bool global_pool = pooling_param.global_pool; + + if (global_pool) { + if (pool_type == 0) { + node_proto->set_op_type("GlobalMaxPool"); + } else { + node_proto->set_op_type("GlobalAveragePool"); + } + return; + } + + // kernel_shape + AttributeProto* const kernel_shape = node_proto->add_attribute(); + kernel_shape->set_name("kernel_shape"); + kernel_shape->set_type(AttributeProto::INTS); + for (int kval : kernel) { + kernel_shape->add_ints(static_cast(kval)); + } + + // pads + AttributeProto* const pads = node_proto->add_attribute(); + pads->set_name("pads"); + pads->set_type(AttributeProto::INTS); + for (int kval : pad) { + 
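// Note: one pad value is emitted per spatial dim here, whereas ConvertConvolution emits begin and end pads. +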
pads->add_ints(static_cast(kval)); + } + + // strides + AttributeProto* const strides = node_proto->add_attribute(); + strides->set_name("strides"); + strides->set_type(AttributeProto::INTS); + for (int kval : stride) { + strides->add_ints(static_cast(kval)); + } + + if (pool_type == 0) { + node_proto->set_op_type("MaxPool"); + } else { + node_proto->set_op_type("AveragePool"); + } // average pooling + // not global pooling +} // end ConvertPooling + +void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& act_param = nnvm::get(attrs.parsed); + std::string act_type; + switch (act_param.act_type) { + case op::activation::kReLU: + act_type = "Relu"; + break; + case op::activation::kSigmoid: + act_type = "Sigmoid"; + break; + case op::activation::kTanh: + act_type = "Tanh"; + break; + case op::activation::kSoftReLU: + // act_type = "SoftReLU"; + throw dmlc::Error("SoftReLU is not supported in ONNX"); + break; + default: + throw dmlc::Error("Activation of such type doesn't exist"); + } + + node_proto->set_op_type(act_type); +} + +void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& act_param = nnvm::get(attrs.parsed); + if (act_param.no_bias) { + node_proto->set_op_type("MatMul"); + } else { + node_proto->set_op_type("Gemm"); + + AttributeProto* const alpha = node_proto->add_attribute(); + alpha->set_name("alpha"); + alpha->set_type(AttributeProto::FLOAT); + alpha->set_f(1.0f); + + AttributeProto* const beta = node_proto->add_attribute(); + beta->set_name("beta"); + beta->set_type(AttributeProto::FLOAT); + beta->set_f(1.0f); + + AttributeProto* const broadcast = node_proto->add_attribute(); + broadcast->set_name("broadcast"); + broadcast->set_type(AttributeProto::INT); + broadcast->set_i(1); + + AttributeProto* const transA = node_proto->add_attribute(); + transA->set_name("transA"); + transA->set_type(AttributeProto::INT); + transA->set_i(0); + + AttributeProto* const transB = node_proto->add_attribute(); + transB->set_name("transB"); + transB->set_type(AttributeProto::INT); + transB->set_i(1); + } +} + +void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Softmax"); + + // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its + // node params. This attribute is only relevant when the input is coerced to 2D, and in that + // case dimension 0 is assumed to be the batch dimension. + AttributeProto* const axis = node_proto->add_attribute(); + axis->set_name("axis"); + axis->set_type(AttributeProto::INT); + axis->set_i(1); +} + +void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Flatten"); + + // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its + // node params. This attribute is only relevant when the input is coerced to 2D, and in that + // case dimension 0 is assumed to be the batch dimension. 
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+}
+
+void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs,
+                      const nnvm::IndexedGraph& /*ig*/,
+                      const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("BatchNormalization");
+  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+
+  AttributeProto* const epsilon = node_proto->add_attribute();
+  epsilon->set_name("epsilon");
+  epsilon->set_type(AttributeProto::FLOAT);
+  epsilon->set_f(static_cast<float>(param.eps));
+
+  AttributeProto* const is_test = node_proto->add_attribute();
+  is_test->set_name("is_test");
+  is_test->set_type(AttributeProto::INT);
+  is_test->set_i(1);
+
+  AttributeProto* const momentum = node_proto->add_attribute();
+  momentum->set_name("momentum");
+  momentum->set_type(AttributeProto::FLOAT);
+  momentum->set_f(param.momentum);
+
+  AttributeProto* const spatial = node_proto->add_attribute();
+  spatial->set_name("spatial");
+  spatial->set_type(AttributeProto::INT);
+  spatial->set_i(1);
+
+  AttributeProto* const consumed = node_proto->add_attribute();
+  consumed->set_name("consumed_inputs");
+  consumed->set_type(AttributeProto::INTS);
+
+  for (int i = 0; i < 5; i++) {
+    int val = (i < 3) ? 0 : 1;
+    consumed->add_ints(static_cast<int64>(val));
+  }
+}
+
+void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/,
+                           const nnvm::IndexedGraph& /*ig*/,
+                           const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
+  node_proto->set_op_type("Add");
+  AttributeProto* const axis = node_proto->add_attribute();
+  axis->set_name("axis");
+  axis->set_type(AttributeProto::INT);
+  axis->set_i(1);
+
+  AttributeProto* const broadcast = node_proto->add_attribute();
+  broadcast->set_name("broadcast");
+  broadcast->set_type(AttributeProto::INT);
+  broadcast->set_i(0);  // 1
+}
+
+std::unordered_map<std::string, TShape> GetPlaceholderShapes(
+    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, TShape> placeholder_shapes;
+  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
+    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
+    TShape shp = shape_inputs[i];
+    if (shp.ndim() > 0) {
+      placeholder_shapes.emplace(name, shp);
+    }
+  }
+  return placeholder_shapes;
+}
+
+std::unordered_map<std::string, uint32_t> GetOutputLookup(
+    const nnvm::IndexedGraph& ig) {
+  std::unordered_map<std::string, uint32_t> output_lookup;
+  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
+      ig.outputs();
+  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
+    const uint32_t id = graph_outputs[i].node_id;
+    const IndexedGraph::Node ig_node = ig[id];
+    const nnvm::Node* const source = ig_node.source;
+    const std::string name = source->attrs.name;
+    output_lookup.emplace(name, i);
+  }
+  return output_lookup;
+}
+
+void ConvertPlaceholder(
+    const std::string& node_name,
+    const std::unordered_map<std::string, TShape>& placeholder_shapes,
+    GraphProto* const graph_proto) {
+  auto val_info_proto = graph_proto->add_input();
+  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
+  auto shape_proto = type_proto->mutable_shape();
+
+  val_info_proto->set_name(node_name);
+  // Will support fp16, etc. in the near future
+  type_proto->set_elem_type(TensorProto_DataType_FLOAT);
+  auto entry_shape = placeholder_shapes.find(node_name)->second;
+
+  for (const auto& elem : entry_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(elem));
+  }
+}
+
+void ConvertConstant(
+    GraphProto* const graph_proto, const std::string& node_name,
+    std::unordered_map<std::string, NDArray>* const shared_buffer) {
+  NodeProto* const node_proto = graph_proto->add_node();
+  node_proto->set_name(node_name);
+  node_proto->add_output(node_name);
+  node_proto->set_op_type("Constant");
+
+  const NDArray nd = shared_buffer->find(node_name)->second;
+  const TBlob& blob = nd.data();
+  const TShape shape = blob.shape_;
+  const int32_t size = shape.Size();
+
+  std::shared_ptr<float> shared_data_ptr(new float[size]);
+  float* const data_ptr = shared_data_ptr.get();
+  nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
+
+  AttributeProto* const tensor_attr = node_proto->add_attribute();
+  tensor_attr->set_name("value");
+  tensor_attr->set_type(AttributeProto::TENSOR);
+
+  TensorProto* const tensor_proto = tensor_attr->mutable_t();
+  tensor_proto->set_data_type(TensorProto_DataType_FLOAT);
+  for (auto& dim : shape) {
+    tensor_proto->add_dims(static_cast<int64>(dim));
+  }
+
+  for (int blob_idx = 0; blob_idx < size; ++blob_idx) {
+    tensor_proto->add_float_data(data_ptr[blob_idx]);
+  }
+}
+
+void ConvertOutput(
+    op::tensorrt::InferenceMap_t* const trt_output_map,
+    GraphProto* const graph_proto,
+    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
+    const std::string& node_name, const nnvm::Graph& g,
+    const StorageTypeVector& storage_types, const DTypeVector& dtypes) {
+  const nnvm::IndexedGraph& ig = g.indexed_graph();
+  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
+  TShape out_shape = g.GetAttr<ShapeVector>("shape")[out_idx];
+  int storage_type = storage_types[out_idx];
+  int dtype = dtypes[out_idx];
+
+  // This should work with fp16 as well
+  op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type,
+                                           dtype};
+
+  trt_output_map->emplace(node_name, out_tuple);
+
+  auto graph_out = graph_proto->add_output();
+  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
+  auto tensor_shape_proto = tensor_type->mutable_shape();
+  graph_out->set_name(node_name);
+
+  // Also support fp16.
+  tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
+
+  for (int64_t dim_shp : out_shape) {
+    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
+    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
+  }
+}
+
+}  // namespace nnvm_to_onnx
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h
new file mode 100644
index 000000000000..be335ab1208f
--- /dev/null
+++ b/src/operator/contrib/tensorrt-inl.h
@@ -0,0 +1,113 @@
+#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt-inl.h
+ * \brief TensorRT Operator
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include <NvInfer.h>
+#include <onnx/onnx.pb.h>
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../operator_common.h"
+#include "../../common/utils.h"
+#include "../../common/serialization.h"
+#include "../../executor/exec_pass.h"
+#include "../../executor/graph_executor.h"
+#include "../../executor/onnx_to_tensorrt.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace nnvm;
+using namespace ::onnx;
+using int64 = ::google::protobuf::int64;
+
+namespace tensorrt {
+  enum class TypeIO { Inputs = 0, Outputs = 1 };
+  using NameToIdx_t = std::map<std::string, int32_t>;
+  using InferenceTuple_t = std::tuple<uint32_t, TShape, int, int>;
+  using InferenceMap_t = std::map<std::string, InferenceTuple_t>;
+}  // namespace tensorrt
+
+using trt_name_to_idx = std::map<std::string, uint32_t>;
+
+struct TRTParam : public dmlc::Parameter<TRTParam> {
+  std::string serialized_onnx_graph;
+  std::string serialized_input_map;
+  std::string serialized_output_map;
+  tensorrt::NameToIdx_t input_map;
+  tensorrt::InferenceMap_t output_map;
+  ::onnx::ModelProto onnx_pb_graph;
+
+  TRTParam() {}
+
+  TRTParam(const ::onnx::ModelProto& onnx_graph,
+           const tensorrt::InferenceMap_t& input_map,
+           const tensorrt::NameToIdx_t& output_map) {
+    common::Serialize(input_map, &serialized_input_map);
+    common::Serialize(output_map, &serialized_output_map);
+    onnx_graph.SerializeToString(&serialized_onnx_graph);
+  }
+
+DMLC_DECLARE_PARAMETER(TRTParam) {
+    DMLC_DECLARE_FIELD(serialized_onnx_graph)
+    .describe("Serialized ONNX graph");
+    DMLC_DECLARE_FIELD(serialized_input_map)
+    .describe("Map from inputs to topological order as input.");
+    DMLC_DECLARE_FIELD(serialized_output_map)
+    .describe("Map from outputs to order in g.outputs.");
+  }
+};
+
+struct TRTEngineParam {
+  nvinfer1::IExecutionContext* trt_executor;
+  std::vector<std::pair<uint32_t, tensorrt::TypeIO> > binding_map;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
+
+#endif  // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_
diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc
new file mode 100644
index 000000000000..619fe1e2b8f4
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cc
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cc
+ * \brief TensorRT operation registration
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "../../common/serialization.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(TRTParam);
+
+OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine,
+                         tensorrt::NameToIdx_t input_map,
+                         tensorrt::NameToIdx_t output_map) {
+  TRTEngineParam param;
+  for (int b = 0; b < trt_engine->getNbBindings(); ++b) {
+    const std::string& binding_name = trt_engine->getBindingName(b);
+    if (trt_engine->bindingIsInput(b)) {
+      param.binding_map.emplace_back(input_map[binding_name],
+                                     tensorrt::TypeIO::Inputs);
+    } else {
+      param.binding_map.emplace_back(output_map[binding_name],
+                                     tensorrt::TypeIO::Outputs);
+    }
+  }
+  param.trt_executor = trt_engine->createExecutionContext();
+  return OpStatePtr::Create<TRTEngineParam>(param);
+}
+
+OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/,
+                          const std::vector<TShape>& /*ishape*/,
+                          const std::vector<int>& /*itype*/) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+
+  ::onnx::ModelProto model_proto;
+  bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph);
+  if (!success) {
+    LOG(FATAL) << "Problems parsing serialized ONNX model.";
+  }
+  auto graph = model_proto.graph();
+  auto first_input_type = graph.input(0).type().tensor_type();
+  auto dim_value = first_input_type.shape().dim(0).dim_value();
+  auto batch_size = static_cast<int>(dim_value);
+  // Need to set up max workspace size based on device properties
+  nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx(
+      node_param.serialized_onnx_graph, batch_size, 1 << 30);
+
+  tensorrt::NameToIdx_t output_map;
+  for (auto& el : node_param.output_map) {
+    output_map[el.first] = std::get<0>(el.second);
+  }
+  return GetPtrMapping(trt_engine, node_param.input_map, output_map);
+}
+
+void TRTParamParser(nnvm::NodeAttrs* attrs) {
+  TRTParam param_;
+
+  try {
+    param_.Init(attrs->dict);
+    common::Deserialize(&param_.input_map, param_.serialized_input_map);
+    common::Deserialize(&param_.output_map, param_.serialized_output_map);
+    param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  attrs->parsed = std::move(param_);
+}
+
+inline bool TRTInferShape(const NodeAttrs& attrs, std::vector<TShape>* /*in_shape*/,
+                          std::vector<TShape>* out_shape) {
+  const auto &node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second);
+  }
+  return true;
+}
+
+inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask*/,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int>* /*in_storage_type*/,
+                                std::vector<int>* out_storage_type) {
+  return storage_type_assign(out_storage_type, mxnet::kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+}
+
+inline bool TRTInferType(const NodeAttrs& attrs, std::vector<int>* /*in_dtype*/,
+                         std::vector<int>* out_dtype) {
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  for (auto& el : node_param.output_map) {
+    (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second);
+  }
+  return true;
+}
+
+inline std::vector<std::string> TRTListInputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.input_map.size());
+  for (auto& el : node_param.input_map) {
+    output[el.second] = el.first;
+  }
+  return output;
+}
+
+inline std::vector<std::string> TRTListOutputNames(const NodeAttrs& attrs) {
+  std::vector<std::string> output;
+  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+  output.resize(node_param.output_map.size());
+  for (auto& el : node_param.output_map) {
+    output[std::get<0>(el.second)] = el.first;
+  }
+  return output;
+}
+
+NNVM_REGISTER_OP(_trt_op)
+    .describe(R"code(TRT operation (one engine)
+)code" ADD_FILELINE)
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.input_map.size();
+    })
+    .set_num_outputs([](const NodeAttrs& attrs) {
+      const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
+      return node_param.output_map.size();
+    })
+    .set_attr_parser(TRTParamParser)
+    .set_attr<nnvm::FInferShape>("FInferShape", TRTInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
+    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
+    .set_attr<nnvm::FListOutputNames>("FListOutputNames", TRTListOutputNames)
+    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
+    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu
new file mode 100644
index 000000000000..2fe8727b73e4
--- /dev/null
+++ b/src/operator/contrib/tensorrt.cu
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file tensorrt.cu
+ * \brief TensorRT GPU operation
+ * \author Marek Kolodziej, Clement Fuji Tsang
+*/
+
+#if MXNET_USE_TENSORRT
+
+#include "./tensorrt-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define CHECK_CUDART(x) do { \
+  cudaError_t res = (x); \
+  if (res != cudaSuccess) { \
+    fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \
+            #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \
+    exit(1); \
+  } \
+} while (0)
+
+void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
+                const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  Stream<gpu>* s = ctx.get_stream<gpu>();
+  cudaStream_t cuda_s = Stream<gpu>::GetStream(s);
+  const auto& param = state.get_state<TRTEngineParam>();
+  std::vector<void*> bindings;
+  bindings.reserve(param.binding_map.size());
+  for (auto& p : param.binding_map) {
+    if (p.second == tensorrt::TypeIO::Inputs) {
+      bindings.emplace_back(inputs[p.first].dptr_);
+    } else {
+      bindings.emplace_back(outputs[p.first].dptr_);
+    }
+  }
+
+  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
+  param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr);
+  CHECK_CUDART(cudaStreamSynchronize(cuda_s));
+}
+
+NNVM_REGISTER_OP(_trt_op)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_TENSORRT
diff --git a/tests/.gitignore b/tests/.gitignore
index d6459089c245..3e5eed695f0a 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1,2 @@
 *_unittest
+*.gz
diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc
new file mode 100644
index 000000000000..96f8b6c3a3a7
--- /dev/null
+++ b/tests/cpp/misc/serialization.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <../../../src/common/serialization.h>
+
+using namespace mxnet;
+using namespace std;
+
+/*
+ * Test that the data structures used are properly serialized and deserialized
+ */
+
+TEST(SerializerTest, InputMapCorrect) {
+  std::map<std::string, int32_t> input_map;
+  input_map.emplace("input_0", 2);
+  input_map.emplace("another_input", 0);
+  input_map.emplace("last_input", 1);
+  std::string serialized_data;
+  common::Serialize(input_map, &serialized_data);
+  std::map<std::string, int32_t> deserialized_input_map;
+  common::Deserialize(&deserialized_input_map, serialized_data);
+  ASSERT_EQ(input_map.size(), deserialized_input_map.size());
+  for (auto& p : input_map) {
+    auto it = deserialized_input_map.find(p.first);
+    ASSERT_NE(it, deserialized_input_map.end());
+    ASSERT_EQ(it->second, p.second);
+  }
+}
+
+TEST(SerializerTest, OutputMapCorrect) {
+  std::map<std::string, std::tuple<uint32_t, TShape, int, int> > output_map;
+  output_map.emplace("output_0", std::make_tuple(1, TShape({23, 12, 63, 432}), 0, 1));
+  output_map.emplace("another_output", std::make_tuple(2, TShape({23, 123}), 14, -23));
+  output_map.emplace("last_output", std::make_tuple(0, TShape({0}), -1, 0));
+  std::string serialized_data;
+  common::Serialize(output_map, &serialized_data);
+  std::map<std::string, std::tuple<uint32_t, TShape, int, int> > deserialized_output_map;
+  common::Deserialize(&deserialized_output_map, serialized_data);
+  ASSERT_EQ(output_map.size(), deserialized_output_map.size());
+  for (auto& p : output_map) {
+    auto it = deserialized_output_map.find(p.first);
+    ASSERT_NE(it, deserialized_output_map.end());
+    auto lhs = it->second;
+    auto rhs = p.second;
+    ASSERT_EQ(std::get<0>(lhs), std::get<0>(rhs));
+    ASSERT_EQ(std::get<1>(lhs), std::get<1>(rhs));
+    ASSERT_EQ(std::get<2>(lhs), std::get<2>(rhs));
+    ASSERT_EQ(std::get<3>(lhs), std::get<3>(rhs));
+  }
+}
+
diff --git a/tests/python/tensorrt/common.py b/tests/python/tensorrt/common.py
new file mode 100644
index 000000000000..eb599f69973c
--- /dev/null
+++ b/tests/python/tensorrt/common.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +import os +from ctypes.util import find_library + + +def check_tensorrt_installation(): + assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library" + + +def merge_dicts(*dict_args): + """Merge arg_params and aux_params to populate shared_buffer""" + result = {} + for dictionary in dict_args: + result.update(dictionary) + return result + + +def get_fp16_infer_for_fp16_graph(): + return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0)) + + +def set_fp16_infer_for_fp16_graph(status=False): + os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status)) diff --git a/tests/python/tensorrt/lenet5_common.py b/tests/python/tensorrt/lenet5_common.py new file mode 100644 index 000000000000..347d6f3c11ba --- /dev/null +++ b/tests/python/tensorrt/lenet5_common.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import mxnet as mx +from common import * + +def get_iters(mnist, batch_size): + """Get MNIST iterators.""" + train_iter = mx.io.NDArrayIter(mnist['train_data'], + mnist['train_label'], + batch_size, + shuffle=True) + val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + all_test_labels = np.array(mnist['test_label']) + return train_iter, val_iter, test_iter, all_test_labels diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py new file mode 100644 index 000000000000..8edd9abf70e7 --- /dev/null +++ b/tests/python/tensorrt/lenet5_train.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import mxnet as mx +from lenet5_common import get_iters + + +def lenet5(): + """LeNet-5 Symbol""" + #pylint: disable=no-member + data = mx.sym.Variable('data') + conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20) + tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") + pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # second conv + conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50) + tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") + pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # first fullc + flatten = mx.sym.Flatten(data=pool2) + fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500) + tanh3 = mx.sym.Activation(data=fc1, act_type="tanh") + # second fullc + fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10) + # loss + lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') + #pylint: enable=no-member + return lenet + + +def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter): + """train LeNet-5 model on MNIST data""" + ctx = mx.gpu(0) + lenet_model = mx.mod.Module(lenet5(), context=ctx) + + lenet_model.fit(train_iter, + eval_data=val_iter, + optimizer='sgd', + optimizer_params={'learning_rate': 0.1, 'momentum': 0.9}, + eval_metric='acc', + batch_end_callback=mx.callback.Speedometer(batch_size, 1), + num_epoch=num_epochs) + + # predict accuracy for lenet + acc = mx.metric.Accuracy() + lenet_model.score(test_iter, acc) + accuracy = acc.get()[1] + assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low" + return lenet_model + + +if __name__ == '__main__': + num_epochs = 10 + batch_size = 128 + model_name = 'lenet5' + model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") + model_file = '%s/%s-symbol.json' % (model_dir, model_name) + params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) + + if not (os.path.exists(model_file) and os.path.exists(params_file)): + mnist = mx.test_utils.get_mnist() + + _, _, _, all_test_labels = get_iters(mnist, batch_size) + + trained_lenet = train_lenet5(num_epochs, batch_size, + *get_iters(mnist, batch_size)[:-1]) + trained_lenet.save_checkpoint(model_name, num_epochs) diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py new file mode 100644 index 000000000000..4fdd522341bc --- /dev/null +++ b/tests/python/tensorrt/test_cvnets.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import gc +import gluoncv +import mxnet as mx +import numpy as np + +from mxnet import gluon +from time import time + +from mxnet.gluon.data.vision import transforms + + +def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) + h, w = 32, 32 + net = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + + if use_tensorrt: + out = net(data) + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) + executor = mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params, + data=(batch_size,3, h, w), + softmax_label=(batch_size,), grad_req='null', + force_rebind=True) + else: + # Convert gluon model to Symbolic + net.hybridize() + net.forward(mx.ndarray.zeros((batch_size, 3, h, w))) + net.export(model_name) + symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0) + executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), + softmax_label=(batch_size,)) + executor.copy_params_from(arg_params, aux_params) + return executor + + +def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128): + executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size) + + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + + all_label_test = np.zeros(num_ex) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + ]) + + data_loader = lambda: gluon.data.DataLoader( + gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), + batch_size=batch_size, shuffle=False, num_workers=num_workers) + + val_data = data_loader() + + for idx, (data, label) in enumerate(val_data): + # Skip last batch if it's undersized. + if data.shape[0] < batch_size: + continue + offset = idx * batch_size + all_label_test[offset:offset + batch_size] = label.asnumpy() + + # warm-up, but don't use result + executor.forward(is_train=False, data=data) + executor.outputs[0].wait_to_read() + + gc.collect() + val_data = data_loader() + example_ct = 0 + start = time() + + # if use_tensorrt: + for idx, (data, label) in enumerate(val_data): + # Skip last batch if it's undersized. 
+ if data.shape[0] < batch_size: + continue + executor.forward(is_train=False, data=data) + preds = executor.outputs[0].asnumpy() + offset = idx * batch_size + all_preds[offset:offset + batch_size, :] = preds[:batch_size] + example_ct += batch_size + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum() + duration = time() - start + + return duration, 100.0 * matches / example_ct + + +def run_experiment_for(model_name, batch_size, num_workers): + print("\n===========================================") + print("Model: %s" % model_name) + print("===========================================") + print("*** Running inference using pure MXNet ***\n") + mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, use_tensorrt=False) + print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct)) + print("\n*** Running inference using MXNet + TensorRT ***\n") + trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, use_tensorrt=True) + print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct)) + speedup = mx_duration / trt_duration + print("TensorRT speed-up (not counting compilation): %.2fx" % speedup) + + acc_diff = abs(mx_pct - trt_pct) + print("Absolute accuracy difference: %f" % acc_diff) + return speedup, acc_diff + + +def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1): + original_try_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + models = [ + 'cifar_resnet20_v1', + 'cifar_resnet56_v1', + 'cifar_resnet110_v1', + 'cifar_resnet20_v2', + 'cifar_resnet56_v2', + 'cifar_resnet110_v2', + 'cifar_wideresnet16_10', + 'cifar_wideresnet28_10', + 'cifar_wideresnet40_8', + 'cifar_resnext29_16x64d' + ] + + num_models = len(models) + + speedups = np.zeros(num_models, dtype=np.float32) + acc_diffs = np.zeros(num_models, dtype=np.float32) + + test_start = time() + + for idx, model in enumerate(models): + speedup, acc_diff = run_experiment_for(model, batch_size, num_workers) + speedups[idx] = speedup + acc_diffs[idx] = acc_diff + assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % ( + tolerance, model) + + print("Perf and correctness checks run on the following models:") + print(models) + mean_speedup = np.mean(speedups) + std_speedup = np.std(speedups) + print("\nSpeedups:") + print(speedups) + print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) + print("Mean speedup: %.2f" % mean_speedup) + print("St. dev. of speedups: %.2f" % std_speedup) + print("\nAcc. differences: %s" % str(acc_diffs)) + + test_duration = time() - test_start + + print("Test duration: %.2f seconds" % test_duration) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_try_value) + + +if __name__ == '__main__': + import nose + + nose.runmodule() diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py new file mode 100644 index 000000000000..25f515a106a6 --- /dev/null +++ b/tests/python/tensorrt/test_cycle.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from common import * + + +def detect_cycle_from(sym, visited, stack): + visited.add(sym.handle.value) + stack.add(sym.handle.value) + for s in sym.get_children(): + if s.handle.value not in visited: + if detect_cycle_from(sym, visited, stack): + return True + elif s.handle.value in stack: + return True + stack.remove(sym.handle.value) + return False + + +def has_no_cycle(sym): + visited = set() + stack = set() + all_nodes = sym.get_internals() + for s in all_nodes: + if s.handle.value in visited: + if detect_cycle_from(s, visited, stack): + return False + return True + + +def test_simple_cycle(): + inp = mx.sym.Variable('input', shape=[1,10]) + A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A') + B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B') + D = mx.sym.sin(data=A, name='D') + C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C') + arg_params = { + 'I_weight': mx.nd.zeros([10,10]), + 'I_bias': mx.nd.zeros([10]), + 'A_weight': mx.nd.zeros([10,10]), + 'A_bias': mx.nd.zeros([10]), + 'B_weight': mx.nd.zeros([10,10]), + 'B_bias': mx.nd.zeros([10]), + } + + executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), + shared_buffer=arg_params, grad_req='null', force_rebind=True) + optimized_graph = mx.contrib.tensorrt.get_optimized_symbol(executor) + assert has_no_cycle(optimized_graph), "The graph optimized by TRT contains a cycle" + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py new file mode 100644 index 000000000000..258686428a45 --- /dev/null +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import numpy as np +import mxnet as mx +from common import * +from lenet5_common import get_iters + + +def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt): + """Run inference with either MXNet or TensorRT""" + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) + + data_size = (batch_size,) + mnist['test_data'].shape[1:] + if use_tensorrt: + all_params = merge_dicts(arg_params, aux_params) + executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params, + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) + else: + executor = sym.simple_bind(ctx=mx.gpu(0), + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) + executor.copy_params_from(arg_params, aux_params) + + # Get this value from all_test_labels + # Also get classes from the dataset + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + + example_ct = 0 + + for idx, dbatch in enumerate(test_iter): + executor.arg_dict["data"][:] = dbatch.data[0] + executor.forward(is_train=False) + offset = idx*batch_size + extent = batch_size if num_ex - offset > batch_size else num_ex - offset + all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent] + example_ct += extent + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum() + + percentage = 100.0 * matches / example_ct + + return percentage + + +def test_tensorrt_inference(): + """Run LeNet-5 inference comparison between MXNet and TensorRT.""" + original_try_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + check_tensorrt_installation() + mnist = mx.test_utils.get_mnist() + num_epochs = 10 + batch_size = 128 + model_name = 'lenet5' + model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") + model_file = '%s/%s-symbol.json' % (model_dir, model_name) + params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) + + _, _, _, all_test_labels = get_iters(mnist, batch_size) + + # Load serialized MXNet model (model-symbol.json + model-epoch.params) + sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) + + print("LeNet-5 test") + print("Running inference in MXNet") + mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, + batch_size=batch_size, use_tensorrt=False) + + print("Running inference in MXNet-TensorRT") + trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, + batch_size=batch_size, use_tensorrt=True) + + print("MXNet accuracy: %f" % mx_pct) + print("MXNet-TensorRT accuracy: %f" % trt_pct) + + assert abs(mx_pct - trt_pct) < 1e-2, \ + """Diff. between MXNet & TensorRT accuracy too high: + MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_try_value) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py new file mode 100644 index 000000000000..fdac859aef6f --- /dev/null +++ b/tests/python/tensorrt/test_training_warning.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import gluoncv +import mxnet as mx + +from tests.python.unittest.common import assertRaises + + +def test_training_without_trt(): + run_resnet(is_train=True, use_tensorrt=False) + + +def test_inference_without_trt(): + run_resnet(is_train=False, use_tensorrt=False) + + +def test_training_with_trt(): + assertRaises(RuntimeError, run_resnet, is_train=True, use_tensorrt=True) + + +def test_inference_with_trt(): + run_resnet(is_train=False, use_tensorrt=True) + + +def run_resnet(is_train, use_tensorrt): + original_trt_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) + ctx = mx.gpu(0) + batch_size = 1 + h = 32 + w = 32 + model_name = 'cifar_resnet20_v1' + resnet = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + out = resnet(data) + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + if is_train: + grad_req = 'write' + else: + grad_req = 'null' + if use_tensorrt: + all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) + mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params, + data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) + else: + softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_trt_value) + + +if __name__ == '__main__': + import nose + nose.runmodule()
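A minimal usage sketch of the Python inference path these tests exercise, relying only on the API introduced in this change (mx.contrib.tensorrt.set_use_tensorrt and mx.contrib.tensorrt.tensorrt_bind). The checkpoint name, input shape, and batch size below are illustrative placeholders, not part of the change itself.

    import mxnet as mx

    batch_size = 128
    # Load a previously trained checkpoint, e.g. the LeNet-5 model saved by lenet5_train.py.
    sym, arg_params, aux_params = mx.model.load_checkpoint('lenet5', 10)

    # Route the FP32 graph through TensorRT at bind time.
    mx.contrib.tensorrt.set_use_tensorrt(True)
    all_params = dict(list(arg_params.items()) + list(aux_params.items()))
    executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params,
                                                 data=(batch_size, 1, 28, 28),
                                                 softmax_label=(batch_size,),
                                                 grad_req='null', force_rebind=True)

    # Run one batch of inference; zeros stand in for real input data.
    executor.arg_dict['data'][:] = mx.nd.zeros((batch_size, 1, 28, 28))
    executor.forward(is_train=False)
    predictions = executor.outputs[0].asnumpy()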