From 268e90b418cf914103c2d535f2cb157fa90336b7 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Mon, 18 Jun 2018 09:25:05 -0700 Subject: [PATCH 01/43] [MXNET-703] TensorRT runtime integration Co-authored-by: Clement Fuji-Tsang Co-authored-by: Kellen Sunderland --- .gitmodules | 3 + 3rdparty/onnx-tensorrt | 1 + CMakeLists.txt | 10 + Jenkinsfile | 28 + Makefile | 8 + amalgamation/amalgamation.py | 14 +- .../Dockerfile.build.ubuntu_gpu_tensorrt | 41 ++ ci/docker/install/tensorrt.sh | 45 ++ ci/docker/runtime_functions.sh | 67 ++ include/mxnet/c_api.h | 7 + include/mxnet/executor.h | 16 +- python/mxnet/base.py | 16 + python/mxnet/executor.py | 18 +- python/mxnet/module/executor_group.py | 2 +- src/c_api/c_api_executor.cc | 17 +- src/common/serialization.h | 319 ++++++++++ src/executor/exec_pass.h | 21 + src/executor/graph_executor.cc | 224 +++++-- src/executor/graph_executor.h | 30 +- src/executor/onnx_to_tensorrt.cc | 148 +++++ src/executor/onnx_to_tensorrt.h | 77 +++ src/executor/tensorrt_pass.cc | 589 ++++++++++++++++++ src/operator/contrib/nnvm_to_onnx-inl.h | 156 +++++ src/operator/contrib/nnvm_to_onnx.cc | 527 ++++++++++++++++ src/operator/contrib/tensorrt-inl.h | 113 ++++ src/operator/contrib/tensorrt.cc | 183 ++++++ src/operator/contrib/tensorrt.cu | 73 +++ tests/cpp/misc/serialization.cc | 68 ++ tests/python/tensorrt/common.py | 47 ++ tests/python/tensorrt/lenet5_common.py | 31 + tests/python/tensorrt/lenet5_train.py | 120 ++++ tests/python/tensorrt/test_cycle.py | 64 ++ tests/python/tensorrt/test_tensorrt_lenet5.py | 97 +++ .../test_tensorrt_resnet_resnext_ssd.py | 260 ++++++++ 34 files changed, 3377 insertions(+), 63 deletions(-) create mode 160000 3rdparty/onnx-tensorrt create mode 100755 ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt create mode 100755 ci/docker/install/tensorrt.sh create mode 100644 src/common/serialization.h create mode 100644 src/executor/onnx_to_tensorrt.cc create mode 100644 src/executor/onnx_to_tensorrt.h create mode 100644 src/executor/tensorrt_pass.cc create mode 100644 src/operator/contrib/nnvm_to_onnx-inl.h create mode 100644 src/operator/contrib/nnvm_to_onnx.cc create mode 100644 src/operator/contrib/tensorrt-inl.h create mode 100644 src/operator/contrib/tensorrt.cc create mode 100644 src/operator/contrib/tensorrt.cu create mode 100644 tests/cpp/misc/serialization.cc create mode 100644 tests/python/tensorrt/common.py create mode 100644 tests/python/tensorrt/lenet5_common.py create mode 100644 tests/python/tensorrt/lenet5_train.py create mode 100644 tests/python/tensorrt/test_cycle.py create mode 100644 tests/python/tensorrt/test_tensorrt_lenet5.py create mode 100644 tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py diff --git a/.gitmodules b/.gitmodules index 9aeb1c754983..836d824a6f5a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,3 +26,6 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm url = https://github.com/dmlc/tvm +[submodule "3rdparty/onnx-tensorrt"] + path = 3rdparty/onnx-tensorrt + url = https://github.com/onnx/onnx-tensorrt.git diff --git a/3rdparty/onnx-tensorrt b/3rdparty/onnx-tensorrt new file mode 160000 index 000000000000..e7be19cff377 --- /dev/null +++ b/3rdparty/onnx-tensorrt @@ -0,0 +1 @@ +Subproject commit e7be19cff377a95817503e8525e20de34cdc574a diff --git a/CMakeLists.txt b/CMakeLists.txt index 483108a68419..e5744249dc0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON) mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF) mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF) +mxnet_option(USE_TENSORRT "Enable infeference optimization with TensorRT." OFF) message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}") if(USE_CUDA AND NOT USE_OLDCMAKECUDA) @@ -185,6 +186,15 @@ if(USE_VTUNE) list(APPEND mxnet_LINKER_LIBS dl) endif() +if(USE_TENSORRT) + message(STATUS "Using TensorRT") + include_directories(3rdparty/onnx-tensorrt/third_party/onnx/build/) + include_directories(3rdparty/onnx-tensorrt/) + include_directories(3rdparty/) + add_definitions(-DMXNET_USE_TENSORRT=1) + add_definitions(-DONNX_NAMESPACE=onnx) +endif() + if(USE_MKLDNN) include(cmake/MklDnn.cmake) # CPU architecture (e.g., C5) can't run on another architecture (e.g., g3). diff --git a/Jenkinsfile b/Jenkinsfile index 6d21f496426e..32b89561aed9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -28,6 +28,7 @@ mx_dist_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3r mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' // timeout in minutes max_time = 120 // assign any caught errors here @@ -372,6 +373,17 @@ try { } } }, + 'TensorRT': { + node('mxnetlinux-cpu') { + ws('workspace/build-tensorrt') { + timeout(time: max_time, unit: 'MINUTES') { + init_git() + docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false) + pack_lib('tensorrt', mx_tensorrt_lib) + } + } + } + }, 'Build CPU windows':{ node('mxnetwindows-cpu') { timeout(time: max_time, unit: 'MINUTES') { @@ -740,6 +752,22 @@ try { } } }, + 'Python3: TensorRT GPU': { + node('mxnetlinux-gpu-p3') { + ws('workspace/build-tensorrt') { + timeout(time: max_time, unit: 'MINUTES') { + try { + init_git() + unpack_lib('tensorrt', mx_tensorrt_lib) + docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) + publish_test_coverage() + } finally { + collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml') + } + } + } + } + }, 'Scala: CPU': { node('mxnetlinux-cpu') { ws('workspace/ut-scala-cpu') { diff --git a/Makefile b/Makefile index 88f7dd9278cb..b794e00f00aa 100644 --- a/Makefile +++ b/Makefile @@ -91,6 +91,14 @@ else endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) + + +ifeq ($(USE_TENSORRT), 1) + CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1 + LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin +endif +# -L/usr/local/lib + ifeq ($(DEBUG), 1) NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) else diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py index 52d775b76925..a3c28f7118e9 100644 --- a/amalgamation/amalgamation.py +++ b/amalgamation/amalgamation.py @@ -23,13 +23,12 @@ import platform blacklist = [ - 'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', - 'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', - 'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h', - 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h', - 'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', - 'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', - 'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h', + 'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h', + 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h', + 'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h', + 'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h', + 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h', + 'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h', 'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp', 'relacy_shims.h', 'ittnotify.h', 'shared_mutex' ] @@ -150,6 +149,7 @@ def expand(x, pending, stage): h not in sysheaders and 'mkl' not in h and 'nnpack' not in h and + 'tensorrt' not in h and not h.endswith('.cuh')): sysheaders.append(h) else: expand.treeDepth += 1 diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt new file mode 100755 index 000000000000..255da316041f --- /dev/null +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt @@ -0,0 +1,41 @@ +# -*- mode: dockerfile -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Dockerfile to run MXNet on Ubuntu 16.04 for CPU + +FROM nvidia/cuda:9.0-cudnn7-devel + +WORKDIR /work/deps + +COPY install/ubuntu_core.sh /work/ +RUN /work/ubuntu_core.sh +COPY install/deb_ubuntu_ccache.sh /work/ +RUN /work/deb_ubuntu_ccache.sh +COPY install/ubuntu_python.sh /work/ +RUN /work/ubuntu_python.sh +COPY install/tensorrt.sh /work +RUN /work/tensorrt.sh + +ARG USER_ID=0 +COPY install/ubuntu_adduser.sh /work/ +RUN /work/ubuntu_adduser.sh + +COPY runtime_functions.sh /work/ + +WORKDIR /work/mxnet +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib diff --git a/ci/docker/install/tensorrt.sh b/ci/docker/install/tensorrt.sh new file mode 100755 index 000000000000..a6258d94f62f --- /dev/null +++ b/ci/docker/install/tensorrt.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Install gluoncv since we're testing Gluon models as well +pip2 install gluoncv==0.2.0 +pip3 install gluoncv==0.2.0 + +# Install Protobuf +# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt) +pushd . +cd .. +apt-get update +apt-get install -y automake libtool +git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git +cd protobuf +./autogen.sh +./configure +make -j$(nproc) +make install +ldconfig +popd + +# Install TensorRT +echo "TensorRT build enabled. Installing TensorRT." +wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb +dpkg -i tensorrt.deb +apt-get update +apt-get install -y --allow-downgrades libnvinfer-dev +rm tensorrt.deb diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index a0795eb58a5a..6e1bd5b0449c 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -436,6 +436,62 @@ build_ubuntu_gpu() { build_ubuntu_gpu_cuda91_cudnn7 } +build_ubuntu_gpu_tensorrt() { + + set -ex + + build_ccache_wrappers + + # Build ONNX + pushd . + echo "Installing ONNX." + cd 3rdparty/onnx-tensorrt/third_party/onnx + rm -rf build + mkdir -p build + cd build + cmake \ + -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER}\ + -DBUILD_SHARED_LIBS=ON ..\ + -G Ninja + ninja -v + export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH + export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH + popd + + # Build ONNX-TensorRT + pushd . + cd 3rdparty/onnx-tensorrt/ + mkdir -p build + cd build + cmake .. + make -j$(nproc) + export LIBRARY_PATH=`pwd`:$LIBRARY_PATH + popd + + mkdir -p /work/mxnet/lib/ + cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/ + cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/ + cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/ + + rm -rf build + make \ + DEV=1 \ + USE_BLAS=openblas \ + USE_CUDA=1 \ + USE_CUDA_PATH=/usr/local/cuda \ + USE_CUDNN=1 \ + USE_OPENCV=0 \ + USE_DIST_KVSTORE=0 \ + USE_TENSORRT=1 \ + USE_JEMALLOC=0 \ + USE_GPERFTOOLS=0 \ + ONNX_NAMESPACE=onnx \ + CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\ + -j$(nproc) + + report_ccache_usage +} + build_ubuntu_gpu_mkldnn() { set -ex @@ -638,6 +694,15 @@ unittest_ubuntu_python3_gpu_nocudnn() { nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu } +unittest_ubuntu_tensorrt_gpu() { + set -ex + export PYTHONPATH=./python/ + export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 + export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH + python tests/python/tensorrt/lenet5_train.py + nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/ +} + # quantization gpu currently only runs on P3 instances # need to separte it from unittest_ubuntu_python2_gpu() unittest_ubuntu_python2_quantization_gpu() { @@ -970,3 +1035,5 @@ EOF declare -F | cut -d' ' -f3 echo fi + + diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 75147cfd706d..58b1b1b4dafe 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1714,6 +1714,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping, NDArrayHandle** aux_states, ExecutorHandle shared_exec, ExecutorHandle *out); + +/*! + * \brief get optimized graph from graph executor + */ +MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, + SymbolHandle *out); + /*! * \brief set a call back to notify the completion of operation */ diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 842653f86537..11d3c5b0127b 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -152,14 +152,14 @@ class Executor { static Executor* SimpleBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::unordered_map& arg_shape_map, - const std::unordered_map& arg_dtype_map, - const std::unordered_map& arg_stype_map, - const std::vector& grad_req_types, - const std::unordered_set& param_names, + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + std::unordered_set* param_names, std::vector* in_args, std::vector* arg_grads, std::vector* aux_states, diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 4df794bdfe37..5f4ae8bd0ac9 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -709,3 +709,19 @@ def write_all_str(module_file, module_all_list): module_op_file.close() write_all_str(module_internal_file, module_internal_all) module_internal_file.close() + +def cint(init_val=0): + """create a C int with an optional initial value""" + return C.c_int(init_val) + +def int_addr(x): + """given a c_int, return it's address as an int ptr""" + x_addr = C.addressof(x) + int_p = C.POINTER(C.c_int) + x_int_addr = C.cast(x_addr, int_p) + return x_int_addr + +def checked_call(f, *args): + """call a cuda function and check for success""" + error_t = f(*args) + assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index c0272c5bb433..adb532ddf4fa 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -24,8 +24,9 @@ import ctypes import copy import numpy as np +import mxnet as mx from .base import _LIB -from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str +from .base import mx_uint, NDArrayHandle, ExecutorHandle, SymbolHandle, py_str from .base import check_call, c_handle_array, c_array_buf, c_str_array from .ndarray import NDArray from .ndarray import _ndarray_cls @@ -73,6 +74,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx): self.aux_arrays = [] self.outputs = self._get_outputs() self._symbol = copy.deepcopy(symbol) + self._optimized_symbol = None self._arg_dict = None self._grad_dict = None self._aux_dict = None @@ -323,6 +325,20 @@ def output_dict(self): self._symbol.list_outputs(), self.outputs) return self._output_dict + @property + def optimized_symbol(self): + """Get optimized symbol. + + Returns + ------- + symbol : nnvm::Symbol + The nnvm symbol optimized. + """ + if self._optimized_symbol is None: + handle = SymbolHandle() + check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(handle))) + return mx.sym.Symbol(handle=handle) + def copy_params_from(self, arg_params, aux_params=None, allow_extra_params=False): """Copy parameters from arg_params, aux_params into executor's internal array. diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index 5d8e95077c40..c4050699bd52 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -592,8 +592,8 @@ def backward(self, out_grads=None): # pylint: disable=no-member og_my_slice = nd.slice_axis(grad, axis=axis, begin=islice.start, end=islice.stop) - # pylint: enable=no-member out_grads_slice.append(og_my_slice.as_in_context(self.contexts[i])) + # pylint: enable=no-member else: out_grads_slice.append(grad.copyto(self.contexts[i])) exec_.backward(out_grads=out_grads_slice) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 09bc23934e5a..723677ccc38f 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -26,6 +26,7 @@ #include #include #include "./c_api_common.h" +#include "../executor/graph_executor.h" int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); @@ -440,9 +441,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector arg_grad_vec; std::vector aux_state_vec; - *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, - grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + *out = Executor::SimpleBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, + &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map, + &grad_req_type_vec, &shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); @@ -597,6 +598,16 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete out); } +int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, + SymbolHandle *out) { + nnvm::Symbol *s = new nnvm::Symbol(); + API_BEGIN(); + exec::GraphExecutor *exec = static_cast(handle); + *s = exec->GetOptimizedSymbol(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle) { diff --git a/src/common/serialization.h b/src/common/serialization.h new file mode 100644 index 000000000000..56b6069304d0 --- /dev/null +++ b/src/common/serialization.h @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2015 by Contributors + * \file serialization.h + * \brief Serialization of some STL and nnvm data-structures + * \author Clement Fuji Tsang + */ + +#ifndef MXNET_COMMON_SERIALIZATION_H_ +#define MXNET_COMMON_SERIALIZATION_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace mxnet { +namespace common { + +template +inline size_t SerializedSize(const T &obj); + +template +inline size_t SerializedSize(const nnvm::Tuple &obj); + +template +inline size_t SerializedSize(const std::map &obj); + +template<> +inline size_t SerializedSize(const std::string &obj); + +template +inline size_t SerializedSize(const std::tuple &obj); + +template +inline void Serialize(const T &obj, char **buffer); + +template +inline void Serialize(const nnvm::Tuple &obj, char **buffer); + +template +inline void Serialize(const std::map &obj, char **buffer); + +template<> +inline void Serialize(const std::string &obj, char **buffer); + +template +inline void Serialize(const std::tuple &obj, char **buffer); + +template +inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(std::map *obj, const std::string &buffer, size_t *curr_pos); + +template<> +inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos); + +template +inline void Deserialize(std::tuple *obj, const std::string &buffer, size_t *curr_pos); + + +template +struct is_container { + static const bool value = !std::is_pod::value; +}; + +template +inline size_t SerializedSize(const T &obj) { + return sizeof(T); +} + +template +inline size_t SerializedSize(const nnvm::Tuple &obj) { + if (is_container::value) { + size_t sum_val = 4; + for (auto& el : obj) { + sum_val += SerializedSize(el); + } + return sum_val; + } else { + return 4 + (obj.ndim() * sizeof(T)); + } +} + +template +inline size_t SerializedSize(const std::map &obj) { + size_t sum_val = 4; + if (is_container::value && is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.first) + SerializedSize(p.second); + } + } else if (is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.first); + } + sum_val += sizeof(V) * obj.size(); + } else if (is_container::value) { + for (const auto& p : obj) { + sum_val += SerializedSize(p.second); + } + sum_val += sizeof(K) * obj.size(); + } else { + sum_val += (sizeof(K) + sizeof(V)) * obj.size(); + } + return sum_val; +} + +template<> +inline size_t SerializedSize(const std::string &obj) { + return obj.size() + 4; +} + +template +struct serialized_size_tuple { + template + static inline size_t Compute(const std::tuple &obj) { + return SerializedSize(std::get(obj)) + serialized_size_tuple::Compute(obj); + } +}; + +template<> +struct serialized_size_tuple<0> { + template + static inline size_t Compute(const std::tuple &obj) { + return SerializedSize(std::get<0>(obj)); + } +}; + +template +inline size_t SerializedSize(const std::tuple &obj) { + return serialized_size_tuple::Compute(obj); +} + +// Serializer + +template +inline size_t SerializedContainerSize(const T &obj, char **buffer) { + uint32_t size = obj.size(); + std::memcpy(*buffer, &size, 4); + *buffer += 4; + return (size_t) size; +} + +template +inline void Serialize(const T &obj, char **buffer) { + std::memcpy(*buffer, &obj, sizeof(T)); + *buffer += sizeof(T); +} + +template +inline void Serialize(const nnvm::Tuple &obj, char **buffer) { + uint32_t size = obj.ndim(); + std::memcpy(*buffer, &size, 4); + *buffer += 4; + for (auto& el : obj) { + Serialize(el, buffer); + } +} + +template +inline void Serialize(const std::map &obj, char **buffer) { + SerializedContainerSize(obj, buffer); + for (auto& p : obj) { + Serialize(p.first, buffer); + Serialize(p.second, buffer); + } +} + +template<> +inline void Serialize(const std::string &obj, char **buffer) { + auto size = SerializedContainerSize(obj, buffer); + std::memcpy(*buffer, &obj[0], size); + *buffer += size; +} + +template +struct serialize_tuple { + template + static inline void Compute(const std::tuple &obj, char **buffer) { + serialize_tuple::Compute(obj, buffer); + Serialize(std::get(obj), buffer); + } +}; + +template<> +struct serialize_tuple<0> { + template + static inline void Compute(const std::tuple &obj, char **buffer) { + Serialize(std::get<0>(obj), buffer); + } +}; + +template +inline void Serialize(const std::tuple &obj, char **buffer) { + serialize_tuple::Compute(obj, buffer); +} + +// Deserializer + +template +inline size_t DeserializedContainerSize(T *obj, const std::string &buffer, size_t *curr_pos) { + uint32_t size = obj->size(); + std::memcpy(&size, &buffer[*curr_pos], 4); + *curr_pos += 4; + return (size_t) size; +} + +template +inline void Deserialize(T *obj, const std::string &buffer, size_t *curr_pos) { + std::memcpy(obj, &buffer[*curr_pos], sizeof(T)); + *curr_pos += sizeof(T); +} + +template +inline void Deserialize(nnvm::Tuple *obj, const std::string &buffer, size_t *curr_pos) { + uint32_t size = obj->ndim(); + std::memcpy(&size, &buffer[*curr_pos], 4); + *curr_pos += 4; + obj->SetDim(size); + for (size_t i = 0; i < size; ++i) { + Deserialize((*obj)[i], buffer, curr_pos); + } +} + +template +inline void Deserialize(std::map *obj, const std::string &buffer, size_t *curr_pos) { + auto size = DeserializedContainerSize(obj, buffer, curr_pos); + K first; + for (size_t i = 0; i < size; ++i) { + Deserialize(&first, buffer, curr_pos); + Deserialize(&(*obj)[first], buffer, curr_pos); + } +} + +template<> +inline void Deserialize(std::string *obj, const std::string &buffer, size_t *curr_pos) { + auto size = DeserializedContainerSize(obj, buffer, curr_pos); + obj->resize(size); + std::memcpy(&(obj->front()), &buffer[*curr_pos], size); + *curr_pos += size; +} + +template +struct deserialize_tuple { + template + static inline void Compute(std::tuple *obj, + const std::string &buffer, size_t *curr_pos) { + deserialize_tuple::Compute(obj, buffer, curr_pos); + Deserialize(&std::get(*obj), buffer, curr_pos); + } +}; + +template<> +struct deserialize_tuple<0> { + template + static inline void Compute(std::tuple *obj, + const std::string &buffer, size_t *curr_pos) { + Deserialize(&std::get<0>(*obj), buffer, curr_pos); + } +}; + +template +inline void Deserialize(std::tuple *obj, const std::string &buffer, size_t *curr_pos) { + deserialize_tuple::Compute(obj, buffer, curr_pos); +} + + +template +inline void Serialize(const T& obj, std::string* serialized_data) { + serialized_data->resize(SerializedSize(obj)); + char* curr_pos = &(serialized_data->front()); + Serialize(obj, &curr_pos); + CHECK_EQ((int64_t)curr_pos - (int64_t)&(serialized_data->front()), + serialized_data->size()); +} + +template +inline void Deserialize(T* obj, const std::string& serialized_data) { + size_t curr_pos = 0; + Deserialize(obj, serialized_data, &curr_pos); + CHECK_EQ(curr_pos, serialized_data.size()); +} + +} // namespace common +} // namespace mxnet +#endif // MXNET_COMMON_SERIALIZATION_H_ diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 26a249118940..8a00b2e8cd24 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -198,6 +198,27 @@ Graph InferStorageType(Graph&& graph, StorageTypeVector&& storage_type_inputs = StorageTypeVector(), const std::string& storage_type_attr_key = ""); +/*! \brief The default storage type inference function, which assigns all undefined + * storage types to kDefaultStorage. If all of input and output storage types + * are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise, + * DispatchMode::kFComputeFallback is assigned to dispatch_mode. + */ +bool DefaultStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *iattr, + std::vector *oattr); + +/*! + * \brief Replace subgraphs by TRT (forward only) + */ +Graph ReplaceSubgraph(Graph&& g, + const std::unordered_set& set_subgraph, + std::unordered_map* const params_map); + +std::vector> GetTrtCompatibleSubsets(const Graph& g, + std::unordered_map* const params_map); + } // namespace exec } // namespace mxnet diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7386de4d12e3..d06ffc3cdd71 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -34,6 +34,13 @@ #include "../common/utils.h" #include "../common/exec_utils.h" +#if MXNET_USE_TENSORRT +#include +#include +#include "./onnx_to_tensorrt.h" +#include "../operator/contrib/tensorrt-inl.h" +#endif // MNET_USE_TENSORRT + namespace mxnet { namespace exec { @@ -782,8 +789,13 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { - EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], - inferred_dtype, aux_state_vec); + auto it = shared_buffer->find(arg_name); + if ( it != shared_buffer->end() ) { + aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top]))); + } else { + aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + aux_state_ctxes[aux_top], inferred_dtype))); + } } // if (has_shared_exec) data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); @@ -843,9 +855,25 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // !shared_arg_names.count(arg_name) // model parameter, row_sparse ndarray sharing enabled bool enable_row_sparse_sharing = true; - in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, - inferred_stype, in_arg_ctxes[arg_top], - shared_buffer, enable_row_sparse_sharing)); + if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + #if MXNET_USE_TENSORRT + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); + } else { + in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + in_arg_ctxes[arg_top], inferred_dtype))); + } + #else + LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't " + << "built with TensorRT. Add USE_TENSORRT = 1 in config.mk"; + #endif + } else { + in_arg_vec->emplace_back(ReshapeOrCreate( + arg_name, inferred_shape, inferred_dtype, inferred_stype, + in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing)); + } + // gradient for model parameter, row_sparse ndarray sharing disabled if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); @@ -941,6 +969,114 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, this->InitOpSegs(); } +/*! + * \brief This function is triggered after each tensorrt subgraph replacement pass. + * Reset arguments of GraphExecutor::Init(...) as some variables (weights and biases) + * are absorbed into the TRT engine it also it rerun attributes inferences accordingly + * to the new topology. + */ +Graph GraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::vector *grad_req_types, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::unordered_map *params_map) { + std::unordered_set to_remove_params; + for (auto& el : *params_map) { + to_remove_params.insert(el.first); + } + + DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) { + to_remove_params.erase(n->attrs.name); + }); + + for (auto& el : to_remove_params) { + params_map->erase(el); + arg_shape_map->erase(el); + arg_dtype_map->erase(el); + arg_stype_map->erase(el); + } + const auto &idx = g.indexed_graph(); + num_forward_inputs_ = idx.input_nodes().size(); + in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + aux_state_ctxes->resize(idx.mutable_input_nodes().size()); + + // create "device" and "context" attrs for the graph + g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types, num_forward_inputs_, + num_forward_outputs_); + + // get number of nodes used in forward pass + num_forward_nodes_ = 0; + for (size_t i = 0; i < num_forward_outputs_; ++i) { + num_forward_nodes_ = std::max( + num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); + } + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string &name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { + arg_stypes[i] = it3->second; + } + } + g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } + + g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + + return g; +} + +/*! + * \brief Return the "optimized" symbol contained in _graph. + * For optimization pass such as TensorRT pass + */ +nnvm::Symbol GraphExecutor::GetOptimizedSymbol() { + Symbol ret; + ret.outputs = std::vector(graph_.outputs.begin(), + graph_.outputs.begin() + num_forward_outputs_); + ret = ret.Copy(); + static const Op* trt_op = Op::Get("_trt_op"); + DFSVisit(ret.outputs, [](const nnvm::NodePtr n) { + if (n->op() == trt_op) { + n->attrs.dict.clear(); + } + }); + return ret; +} + /*! * \brief GraphExecutor initializer for simple bind flow in * which only certain input shapes and dtypes are provided by users. @@ -958,22 +1094,23 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, void GraphExecutor::Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::unordered_map& arg_shape_map, - const std::unordered_map& arg_dtype_map, - const std::unordered_map& arg_stype_map, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + std::unordered_set* shared_arg_names, std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec, std::unordered_map* shared_buffer, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { - nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, - aux_state_ctxes, grad_req_types); + symbol = symbol.Copy(); + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types); // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation // should do this differently. @@ -987,16 +1124,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol, for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; - auto it1 = arg_shape_map.find(name); - if (arg_shape_map.end() != it1) { + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { arg_shapes[i] = it1->second; } - auto it2 = arg_dtype_map.find(name); - if (arg_dtype_map.end() != it2) { + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { arg_dtypes[i] = it2->second; } - auto it3 = arg_stype_map.find(name); - if (arg_stype_map.end() != it3) { + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { arg_stypes[i] = it3->second; } } @@ -1018,20 +1155,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g.GetAttr("storage_type")); } + if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + #if MXNET_USE_TENSORRT + auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); + for (auto trt_group : trt_groups) { + if (trt_group.size() > 1) { + g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer); + g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map, + arg_stype_map, shared_buffer); + } + } + #else + LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set but MXNet wasn't " + << "built with TensorRT. Add USE_TENSORRT = 1 to config.mk"; + #endif + } + // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. if (nullptr == shared_buffer) { // regular simple bind - InitArguments(idx, g.GetAttr("shape"), + InitArguments(g.indexed_graph(), g.GetAttr("shape"), g.GetAttr("dtype"), g.GetAttr("storage_type"), - in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); + *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, + *grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec - InitArguments(idx, g.GetAttr("shape"), + InitArguments(g.indexed_graph(), g.GetAttr("shape"), g.GetAttr("dtype"), g.GetAttr("storage_type"), - in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - grad_req_types, shared_arg_names, shared_exec, + *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, + *grad_req_types, *shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); } // The above code of shape and dtype inferences and argument @@ -1704,14 +1858,14 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, Executor *Executor::SimpleBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::unordered_map& arg_shape_map, - const std::unordered_map& arg_dtype_map, - const std::unordered_map& arg_stype_map, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + std::unordered_set* shared_arg_names, std::vector* in_args, std::vector* arg_grads, std::vector* aux_states, diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index bfc415b4526a..e380344bf7ce 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -91,14 +91,14 @@ class GraphExecutor : public Executor { void Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::unordered_map& arg_shape_map, - const std::unordered_map& arg_dtype_map, - const std::unordered_map& arg_stype_map, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + std::unordered_set* shared_arg_names, std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec, @@ -117,6 +117,8 @@ class GraphExecutor : public Executor { std::vector* arg_grads, std::vector* aux_states) override; + nnvm::Symbol GetOptimizedSymbol(); + protected: friend class mxnet::Imperative; // Information about operational node @@ -215,6 +217,18 @@ class GraphExecutor : public Executor { // indicate whether there is a backward graph for gradients. bool need_grad_; + + Graph ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::vector* grad_req_types, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::unordered_map *params_map); + // internal graph nnvm::Graph graph_; // operator node diff --git a/src/executor/onnx_to_tensorrt.cc b/src/executor/onnx_to_tensorrt.cc new file mode 100644 index 000000000000..0b4d91be7009 --- /dev/null +++ b/src/executor/onnx_to_tensorrt.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file onnx_to_tensorrt.cc + * \brief TensorRT integration with the MXNet executor + * \author Marek Kolodziej, Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include "./onnx_to_tensorrt.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::cout; +using std::cerr; +using std::endl; + +namespace onnx_to_tensorrt { + +struct InferDeleter { + template + void operator()(T* obj) const { + if ( obj ) { + obj->destroy(); + } + } +}; + +template +inline std::shared_ptr InferObject(T* obj) { + if ( !obj ) { + throw std::runtime_error("Failed to create object"); + } + return std::shared_ptr(obj, InferDeleter()); +} + +std::string onnx_ir_version_string(int64_t ir_version = onnx::IR_VERSION) { + int onnx_ir_major = ir_version / 1000000; + int onnx_ir_minor = ir_version % 1000000 / 10000; + int onnx_ir_patch = ir_version % 10000; + return (std::to_string(onnx_ir_major) + "." + + std::to_string(onnx_ir_minor) + "." + + std::to_string(onnx_ir_patch)); +} + +void PrintVersion() { + cout << "Parser built against:" << endl; + cout << " ONNX IR version: " << onnx_ir_version_string(onnx::IR_VERSION) << endl; + cout << " TensorRT version: " + << NV_TENSORRT_MAJOR << "." + << NV_TENSORRT_MINOR << "." + << NV_TENSORRT_PATCH << endl; +} + +nvinfer1::ICudaEngine* onnxToTrtCtx( + const std::string& onnx_model, + int32_t max_batch_size, + size_t max_workspace_size, + nvinfer1::ILogger::Severity verbosity, + bool debug_builder) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + TRT_Logger trt_logger(verbosity); + auto trt_builder = InferObject(nvinfer1::createInferBuilder(trt_logger)); + auto trt_network = InferObject(trt_builder->createNetwork()); + auto trt_parser = InferObject(nvonnxparser::createParser( + *trt_network, trt_logger)); + ::ONNX_NAMESPACE::ModelProto parsed_model; + // We check for a valid parse, but the main effect is the side effect + // of populating parsed_model + if (!parsed_model.ParseFromString(onnx_model)) { + throw dmlc::Error("Could not parse ONNX from string"); + } + + if ( !trt_parser->parse(onnx_model.c_str(), onnx_model.size()) ) { + int nerror = trt_parser->getNbErrors(); + for ( int i=0; i < nerror; ++i ) { + nvonnxparser::IParserError const* error = trt_parser->getError(i); + if ( error->node() != -1 ) { + ::ONNX_NAMESPACE::NodeProto const& node = + parsed_model.graph().node(error->node()); + cerr << "While parsing node number " << error->node() + << " [" << node.op_type(); + if ( !node.output().empty() ) { + cerr << " -> \"" << node.output(0) << "\""; + } + cerr << "]:" << endl; + if ( static_cast(verbosity) >= \ + static_cast(nvinfer1::ILogger::Severity::kINFO) ) { + cerr << "--- Begin node ---" << endl; + cerr << node.DebugString() << endl; + cerr << "--- End node ---" << endl; + } + } + cerr << "ERROR: " + << error->file() << ":" << error->line() + << " In function " << error->func() << ":\n" + << "[" << static_cast(error->code()) << "] " << error->desc() + << endl; + } + throw dmlc::Error("Cannot parse ONNX into TensorRT Engine"); + } + + bool fp16 = trt_builder->platformHasFastFp16(); + + trt_builder->setMaxBatchSize(max_batch_size); + trt_builder->setMaxWorkspaceSize(max_workspace_size); + if ( fp16 && dmlc::GetEnv("MXNET_TENSORRT_USE_FP16_FOR_FP32", false) ) { + LOG(INFO) << "WARNING: TensorRT using fp16 given original MXNet graph in fp32 !!!"; + trt_builder->setHalf2Mode(true); + } + + trt_builder->setDebugSync(debug_builder); + nvinfer1::ICudaEngine* trt_engine = trt_builder->buildCudaEngine(*trt_network.get()); + return trt_engine; +} + +} // namespace onnx_to_tensorrt + +#endif // MXNET_USE_TENSORRT diff --git a/src/executor/onnx_to_tensorrt.h b/src/executor/onnx_to_tensorrt.h new file mode 100644 index 000000000000..259cfce7c332 --- /dev/null +++ b/src/executor/onnx_to_tensorrt.h @@ -0,0 +1,77 @@ +#ifndef MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ +#define MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file onnx_to_tensorrt.h + * \brief TensorRT integration with the MXNet executor + * \author Marek Kolodziej, Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include + +#include "../operator/contrib/tensorrt-inl.h" + +namespace onnx_to_tensorrt { + +class TRT_Logger : public nvinfer1::ILogger { + nvinfer1::ILogger::Severity _verbosity; + std::ostream* _ostream; + public: + TRT_Logger(Severity verbosity = Severity::kWARNING, + std::ostream& ostream = std::cout) + : _verbosity(verbosity), _ostream(&ostream) {} + void log(Severity severity, const char* msg) override { + if ( severity <= _verbosity ) { + time_t rawtime = std::time(0); + char buf[256]; + strftime(&buf[0], 256, + "%Y-%m-%d %H:%M:%S", + std::gmtime(&rawtime)); + const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : + severity == Severity::kERROR ? " ERROR" : + severity == Severity::kWARNING ? "WARNING" : + severity == Severity::kINFO ? " INFO" : + "UNKNOWN"); + (*_ostream) << "[" << buf << " " << sevstr << "] " + << msg + << std::endl; + } + } +}; + +nvinfer1::ICudaEngine* onnxToTrtCtx( + const std::string& onnx_model, + int32_t max_batch_size = 32, + size_t max_workspace_size = 1L << 30, + nvinfer1::ILogger::Severity verbosity = nvinfer1::ILogger::Severity::kWARNING, + bool debug_builder = false); +} // namespace onnx_to_tensorrt + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_EXECUTOR_ONNX_TO_TENSORRT_H_ diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc new file mode 100644 index 000000000000..b2eeec1bf82c --- /dev/null +++ b/src/executor/tensorrt_pass.cc @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file tensorrt_pass.cc + * \brief Replace TRT compatible subgraphs by TRT engines + * \author Clement Fuji Tsang + */ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include +#include + +#include "../operator/contrib/nnvm_to_onnx-inl.h" +#include "./exec_pass.h" +#include "./onnx_to_tensorrt.h" + +namespace mxnet { +namespace exec { + +using NodePtr = nnvm::NodePtr; + +/*! + * \brief Custom graph class, which will contain bi-directional nodes + * we need to compute DFS and reverse DFS for graph partitioning + */ +class BidirectionalGraph { + public: + struct Node { + nnvm::Node* nnvmptr; + std::vector inputs; + std::vector outputs; + }; + std::vector nodes; + std::unordered_map nnvm2nid; + std::vector outputs; + static const std::unordered_set unconditionalTRTop; + + explicit BidirectionalGraph(const Graph &g) { + auto& idx = g.indexed_graph(); + auto num_nodes = idx.num_nodes(); + nodes.reserve(num_nodes); + nnvm2nid.reserve(num_nodes); + outputs.reserve(idx.outputs().size()); + DFSVisit(g.outputs, [this](const nnvm::NodePtr& n) { + BidirectionalGraph::Node new_node; + new_node.nnvmptr = n.get(); + nnvm2nid[n.get()] = static_cast(nodes.size()); + nodes.emplace_back(std::move(new_node)); + }); + for (const auto& it : nnvm2nid) { + nnvm::Node* nnvmnode = it.first; + uint32_t nid = it.second; + for (auto& n : nnvmnode->inputs) { + uint32_t input_nid = nnvm2nid[n.node.get()]; + nodes[input_nid].outputs.emplace_back(&nodes[nid]); + nodes[nid].inputs.emplace_back(&nodes[input_nid]); + } + } + for (auto& e : g.outputs) { + uint32_t nid = nnvm2nid[e.node.get()]; + outputs.emplace_back(&nodes[nid]); + } + } + + template + void DFS(const std::vector& heads, bool reverse, FVisit fvisit) { + std::unordered_set visited; + std::vector vec(heads.begin(), heads.end()); + visited.reserve(heads.size()); + while (!vec.empty()) { + Node* vertex = vec.back(); + vec.pop_back(); + if (visited.count(vertex) == 0) { + visited.insert(vertex); + fvisit(vertex); + std::vector nexts = reverse ? vertex->inputs : vertex->outputs; + for (Node* node : nexts) { + if (visited.count(node) == 0) { + vec.emplace_back(node); + } + } + } + } + } + + using t_pairset = std::pair, std::unordered_set>; + using t_pairvec = std::pair, std::vector>; + using t_uncomp_map = std::unordered_map>; + + std::unordered_set naive_grow_subgraph(Node* head, + std::unordered_set* set_unused, + t_uncomp_map* uncomp_map) { + std::unordered_set subgraph; + std::unordered_set uncomp_set; + std::deque stack; + stack.emplace_back(head); + while (!stack.empty()) { + Node* vertex = stack.back(); + stack.pop_back(); + if (set_unused->count(vertex) && !uncomp_set.count(vertex)) { + set_unused->erase(vertex); + subgraph.insert(vertex); + uncomp_set.insert((*uncomp_map)[vertex].begin(), (*uncomp_map)[vertex].end()); + for (Node* input : vertex->inputs) { + if (set_unused->count(input) && !uncomp_set.count(input)) { + stack.emplace_back(input); + } + } + for (Node* output : vertex->outputs) { + if (set_unused->count(output) && !uncomp_set.count(output)) { + stack.emplace_back(output); + } + } + } + } + return subgraph; + } + + std::vector> get_subsets( + std::unordered_map* const params_map) { + std::vector> subgraphs; + std::unordered_set set_nonTRTnodes; + std::unordered_set set_allnodes(nodes.size()); + std::vector separation_sets; + for (Node& node : nodes) { + if (!IsTRTCompatible(node.nnvmptr)) { + set_nonTRTnodes.insert(&node); + std::unordered_set in_graph; + std::unordered_set out_graph; + std::vector dummy_head; + dummy_head.emplace_back(&node); + DFS(dummy_head, false, [&out_graph](Node* node) { + out_graph.insert(node); + }); + DFS(dummy_head, true, [&in_graph](Node* node) { + in_graph.insert(node); + }); + separation_sets.emplace_back(std::make_pair(in_graph, out_graph)); + } + set_allnodes.emplace(&node); + } + t_uncomp_map uncomp_map; + std::unordered_set set_TRTnodes; + set_TRTnodes.insert(set_allnodes.begin(), set_allnodes.end()); + for (Node* n : set_nonTRTnodes) { + set_TRTnodes.erase(n); + } + for (Node* n : set_TRTnodes) { + for (t_pairset p : separation_sets) { + if (p.first.count(n)) { + uncomp_map[n].insert(p.second.begin(), p.second.end()); + } else if (p.second.count(n)) { + uncomp_map[n].insert(p.first.begin(), p.first.end()); + } + } + for (Node* nonTRTn : set_nonTRTnodes) { + uncomp_map[n].erase(nonTRTn); + } + } + std::unordered_set set_unused; + set_unused.reserve(set_TRTnodes.size()); + + for (auto& n : set_TRTnodes) { + if (n->nnvmptr->attrs.op != nullptr || params_map->count(n->nnvmptr->attrs.name)) { + set_unused.insert(n); + } + } + std::unordered_set visited; + std::deque stack(outputs.begin(), outputs.end()); + while (!stack.empty()) { + Node* vertex = stack.front(); + stack.pop_front(); + if (!visited.count(vertex)) { + visited.insert(vertex); + if (set_unused.count(vertex)) { + subgraphs.emplace_back(naive_grow_subgraph(vertex, &set_unused, &uncomp_map)); + } + for (Node* input : vertex->inputs) { + stack.emplace_back(input); + } + } + } + + return subgraphs; + } + + + private: + friend class Graph; + + bool IsTRTCompatible(nnvm::Node* nodeptr) { + if (nodeptr->op() == nullptr) { + return true; + } + + const std::string op_name = nodeptr->op()->name; + if (op_name == "Pooling") { + return (nodeptr->attrs.dict.at("pool_type") == "avg" || + nodeptr->attrs.dict.at("pool_type") == "max"); + } + + if (unconditionalTRTop.count(op_name)) { + return true; + } + + if (op_name == "Activation") { + return nodeptr->attrs.dict.at("act_type") == "relu" || + nodeptr->attrs.dict.at("act_type") == "tanh" || + nodeptr->attrs.dict.at("act_type") == "sigmoid"; + } + + return false; + } +}; // class BidirectionalGraph + +/*! + * \brief function which transform std::vector back to Attrs (dmlc::any) + */ +const std::unordered_set BidirectionalGraph::unconditionalTRTop = { + "Convolution", + "BatchNorm", + "elemwise_add", + "elemwise_sub", + "elemwise_mul", + "rsqrt", + "pad", + "Pad", + "mean", + "FullyConnected", + "Flatten", + "SoftmaxOutput", +}; + + +using NodeEntrySet = std::unordered_set; + +/*! + * \brief get the output nodes of the subgraph in the main graph + * \return a vector of the output nodes +*/ +std::vector GetSubgraphOutputs(Graph g, + std::unordered_set set_subgraph) { + std::vector outputs; + NodeEntrySet _outputs; + for (auto& e : g.outputs) { + if (set_subgraph.count(e.node.get())) { + _outputs.insert(e); + } + } + DFSVisit(g.outputs, [&set_subgraph, &_outputs](const nnvm::NodePtr &node){ + if (!set_subgraph.count(node.get())) { + for (auto& e : node->inputs) { + if (set_subgraph.count(e.node.get())) { + _outputs.insert(e); + } + } + } + }); + outputs.insert(outputs.begin(), _outputs.begin(), _outputs.end()); + return outputs; +} + + +/*! + * \brief get the nodes outside of the subgraph for which outputs are used in the subgraph + * \return a vector the nodes +*/ +std::vector GetSubgraphInterfaceNodes(Graph g, + std::unordered_set set_subgraph) { + std::vector inputs; + NodeEntrySet _inputs; + DFSVisit(g.outputs, [&set_subgraph, &_inputs](const nnvm::NodePtr &node){ + if (set_subgraph.count(node.get())) { + for (auto& e : node->inputs) { + if (!set_subgraph.count(e.node.get())) { + _inputs.insert(e); + } + } + } + }); + inputs.insert(inputs.begin(), _inputs.begin(), _inputs.end()); + return inputs; +} + +std::unordered_map GetGraphInputsMap(const Graph& g) { + std::unordered_map outputs; + auto& idx = g.indexed_graph(); + outputs.reserve(idx.num_nodes()); + std::vector input_nodes = idx.input_nodes(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + outputs[input_nodes[i]] = static_cast(i); + } + return outputs; +} + +/*! + * \brief Dummy function which creates a fake TensorRT Node + */ +nnvm::NodePtr ConvertNnvmGraphToOnnx(const nnvm::Graph &g, + std::unordered_map* const params_map) { + auto p = nnvm::Node::Create(); + p->attrs.op = nnvm::Op::Get("_trt_op"); + op::TRTParam trt_param = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(g, params_map); + p->attrs.dict["serialized_output_map"] = trt_param.serialized_output_map; + p->attrs.dict["serialized_input_map"] = trt_param.serialized_input_map; + p->attrs.dict["serialized_onnx_graph"] = trt_param.serialized_onnx_graph; + if (p->op()->attr_parser != nullptr) { + p->op()->attr_parser(&(p->attrs)); + } + return p; +} + +/*! + * \brief Update attributes of the graph (such as some inputs properties) + */ +Graph UpdateSubgraphAttrs(Graph&& subgraph, const Graph& g, + const std::unordered_map& old2new, + const nnvm::NodeEntryMap& main_input_entry_to_sub) { + const auto& idx = g.indexed_graph(); + const auto& sub_idx = subgraph.indexed_graph(); + + const auto& shape = g.GetAttr("shape"); + const auto& dtype = g.GetAttr("dtype"); + const auto& storage_type = g.GetAttr("storage_type"); + const auto& shape_inputs = g.GetAttr("shape_inputs"); + const auto& dtype_inputs = g.GetAttr("dtype_inputs"); + const auto& storage_type_inputs = g.GetAttr("storage_type_inputs"); + + nnvm::ShapeVector sub_shape(sub_idx.num_node_entries()); + nnvm::DTypeVector sub_dtype(sub_idx.num_node_entries()); + StorageTypeVector sub_storage_type(sub_idx.num_node_entries()); + nnvm::ShapeVector sub_shape_inputs(sub_idx.input_nodes().size()); + nnvm::DTypeVector sub_dtype_inputs(sub_idx.input_nodes().size()); + StorageTypeVector sub_storage_type_inputs(sub_idx.input_nodes().size()); + + const std::unordered_map inputsindex2pos = GetGraphInputsMap(g); + const std::unordered_map sub_inputsindex2pos = GetGraphInputsMap(subgraph); + // map attributes from graph to subgraph + for (auto& p : old2new) { + const uint32_t nid = idx.node_id(p.first); + const uint32_t sub_nid = sub_idx.node_id(p.second.get()); + const nnvm::Op* op = sub_idx[sub_nid].source->op(); + if (op == nullptr) { // if it's an input node, there is only one output node entry + const uint32_t sub_i = sub_idx.entry_id(sub_nid, 0); + const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid); + const uint32_t i = idx.entry_id(nid, 0); + + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + sub_shape_inputs[sub_input_i] = shape_inputs[inputsindex2pos.at(nid)]; + sub_dtype_inputs[sub_input_i] = dtype_inputs[inputsindex2pos.at(nid)]; + sub_storage_type_inputs[sub_input_i] = storage_type_inputs[inputsindex2pos.at(nid)]; + + } else { + for (size_t oi = 0; oi < op->num_outputs; ++oi) { + const uint32_t sub_i = sub_idx.entry_id(sub_nid, oi); + const uint32_t i = idx.entry_id(nid, oi); + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + } + } + } + // old2new doesn't contain placeholder / interfaces + for (auto& p : main_input_entry_to_sub) { + nnvm::NodeEntry main_entry = p.first; + nnvm::NodeEntry sub_entry = p.second; + const uint32_t sub_nid = sub_idx.node_id(sub_entry.node.get()); + const uint32_t sub_i = sub_idx.entry_id(sub_entry); + const uint32_t i = idx.entry_id(main_entry); + const uint32_t sub_input_i = sub_inputsindex2pos.at(sub_nid); + sub_shape[sub_i] = shape[i]; + sub_dtype[sub_i] = dtype[i]; + sub_storage_type[sub_i] = storage_type[i]; + sub_shape_inputs[sub_input_i] = sub_shape[sub_i]; + sub_dtype_inputs[sub_input_i] = sub_dtype[sub_i]; + sub_storage_type_inputs[sub_input_i] = sub_storage_type[sub_i]; + } + subgraph.attrs["shape"] = + std::make_shared(std::move(sub_shape)); + subgraph.attrs["dtype"] = + std::make_shared(std::move(sub_dtype)); + subgraph.attrs["storage_type"] = + std::make_shared(std::move(sub_storage_type)); + subgraph.attrs["shape_inputs"] = + std::make_shared(std::move(sub_shape_inputs)); + subgraph.attrs["dtype_inputs"] = + std::make_shared(std::move(sub_dtype_inputs)); + subgraph.attrs["storage_type_inputs"] = + std::make_shared(std::move(sub_storage_type_inputs)); + + return subgraph; +} + +/*! + * \brief Generate a name for a new TRT node, avoid collision if some TRT_nodes are already defined + */ +const std::string GetNewTrtName(const Graph& g, const Graph& subgraph) { + const std::string name_prefix("TRT_node"); + std::unordered_set name_set; + DFSVisit(g.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) { + if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) { + name_set.insert(node->attrs.name); + } + }); + // name inside the subgraph will be avaible as they will be removed + DFSVisit(subgraph.outputs, [&name_set, &name_prefix](const nnvm::NodePtr& node) { + if (node->attrs.name.compare(0, name_prefix.size(), name_prefix) == 0) { + name_set.erase(node->attrs.name); + } + }); + uint32_t name_suffix = 0; + std::string full_name = name_prefix + std::to_string(name_suffix); + while (name_set.count(full_name)) { + full_name = name_prefix + std::to_string(++name_suffix); + } + return full_name; +} + +/*! + * \brief helper function to display what nodes are in a specific subset + */ +void dispNodesSet(Graph g, std::unordered_set s) { + DFSVisit(g.outputs, [&s](const nnvm::NodePtr n){ + if (s.count(n.get())) { + std::cout << " Y " << n->attrs.name << std::endl; + } else { + std::cout << " N " << n->attrs.name << std::endl; + } + }); +} + +/*! + * \brief Replace a set of nodes by a TensorRT node + */ +Graph ReplaceSubgraph(Graph&& g, + const std::unordered_set& set_subgraph, + std::unordered_map* const params_map) { + // Create MXNet subgraph + Graph subgraph; + + const auto sub_outputs_in_main = GetSubgraphOutputs(g, set_subgraph); + subgraph.outputs = sub_outputs_in_main; + std::unordered_map old2new; + std::deque stack; + std::unordered_set visited; + int32_t reservation = set_subgraph.size(); + old2new.reserve(reservation); + visited.reserve(reservation); + + for (auto& n : set_subgraph) { + old2new[n] = std::make_shared(*n); + } + + nnvm::NodeEntryMap main_input_entry_to_sub; + for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) { + auto node = nnvm::Node::Create(); + node->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index); + auto new_e = nnvm::NodeEntry{node, 0, 0}; + main_input_entry_to_sub[e] = new_e; + } + + for (nnvm::NodeEntry& e : subgraph.outputs) { + e.node = old2new[e.node.get()]; + stack.emplace_back(e.node.get()); + } + + while (!stack.empty()) { + auto vertex = stack.front(); + stack.pop_front(); + if (!visited.count(vertex)) { + visited.insert(vertex); + for (auto& e : vertex->inputs) { + auto it = main_input_entry_to_sub.find(e); + if (it != main_input_entry_to_sub.end()) { + e = it->second; + } else { + e.node = old2new[e.node.get()]; + } + stack.emplace_back(e.node.get()); + } + } + } + + DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) { + std::remove_if(node->control_deps.begin(), + node->control_deps.end(), + [&set_subgraph](nnvm::NodePtr n_ptr) { + return !set_subgraph.count(n_ptr.get()); + }); + for (nnvm::NodePtr& n_ptr : node->control_deps) { + n_ptr = old2new[n_ptr.get()]; + } + }); + + subgraph = UpdateSubgraphAttrs(std::move(subgraph), g, old2new, main_input_entry_to_sub); + auto& sub_idx = subgraph.indexed_graph(); + + auto trtnodeptr = ConvertNnvmGraphToOnnx(subgraph, params_map); + trtnodeptr->attrs.name = GetNewTrtName(g, subgraph); + + // Insert new trt node and unplug replaced nodes + std::unordered_map sub_input_entryid_to_main; + for (auto& p : main_input_entry_to_sub) { + sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first; + } + + trtnodeptr->inputs.resize(main_input_entry_to_sub.size()); + { + uint32_t counter = 0; + for (uint32_t i : sub_idx.input_nodes()) { + auto it = sub_input_entryid_to_main.find(sub_idx.entry_id(i, 0)); + if (it != sub_input_entryid_to_main.end()) { + trtnodeptr->inputs[counter++] = it->second; + } + } + } + + nnvm::NodeEntryMap sub_outputs_in_main_to_pos; + for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) { + sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i; + } + DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) { + for (auto& e : n->inputs) { + auto it = sub_outputs_in_main_to_pos.find(e); + if (it != sub_outputs_in_main_to_pos.end()) { + e.index = it->second; + e.node = trtnodeptr; + } + } + }); + + for (auto& output : g.outputs) { + auto it = sub_outputs_in_main_to_pos.find(output); + if (it != sub_outputs_in_main_to_pos.end()) { + output.index = it->second; + output.node = trtnodeptr; + } + } + + Graph new_graph; + new_graph.outputs = g.outputs; + return new_graph; +} + +std::vector> GetTrtCompatibleSubsets(const Graph& g, + std::unordered_map* const params_map) { + BidirectionalGraph biG = BidirectionalGraph(g); + std::vector> subsets = biG.get_subsets(params_map); + std::vector> nnvm_subsets(subsets.size(), + std::unordered_set()); + for (size_t i = 0; i < subsets.size(); ++i) { + nnvm_subsets[i].reserve(subsets[i].size()); + for (auto& n : subsets[i]) { + nnvm_subsets[i].insert(n->nnvmptr); + } + } + return nnvm_subsets; +} + +} // namespace exec +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/operator/contrib/nnvm_to_onnx-inl.h b/src/operator/contrib/nnvm_to_onnx-inl.h new file mode 100644 index 000000000000..58f88b051433 --- /dev/null +++ b/src/operator/contrib/nnvm_to_onnx-inl.h @@ -0,0 +1,156 @@ +#ifndef MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ +#define MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file tensorrt-inl.h + * \brief TensorRT Operator + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./tensorrt-inl.h" +#include "../operator_common.h" +#include "../../common/utils.h" +#include "../../common/serialization.h" + +namespace mxnet { +namespace op { +namespace nnvm_to_onnx { + +using namespace nnvm; +using namespace ::onnx; +using int64 = ::google::protobuf::int64; + +std::unordered_map GetPlaceholderShapes(const ShapeVector& shape_inputs, + const nnvm::IndexedGraph& ig); + +std::unordered_map GetOutputLookup(const nnvm::IndexedGraph& ig); + +void ConvertPlaceholder( + const std::string& node_name, + const std::unordered_map& placeholder_shapes, + GraphProto* const graph_proto); + +void ConvertConstant(GraphProto* const graph_proto, + const std::string& node_name, + std::unordered_map* const shared_buffer); + +void ConvertOutput(op::tensorrt::InferenceMap_t* const trt_output_map, + GraphProto* const graph_proto, + const std::unordered_map::iterator& out_iter, + const std::string& node_name, + const nnvm::Graph& g, + const StorageTypeVector& storage_types, + const DTypeVector& dtypes); + +typedef void (*ConverterFunction)(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + + +// Forward declarations +void ConvertConvolution( + NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + + +void ConvertPooling(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertActivation(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertFullyConnected(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertSoftmaxOutput(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertFlatten(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertBatchNorm(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +void ConvertElementwiseAdd(NodeProto *node_proto, + const NodeAttrs &attrs, + const nnvm::IndexedGraph &ig, + const array_view &inputs); + +TRTParam ConvertNnvmGraphToOnnx( + const nnvm::Graph &g, + std::unordered_map *const shared_buffer); + +static const std::unordered_map converter_map = { + {"Convolution", ConvertConvolution}, + {"Pooling", ConvertPooling}, + {"Activation", ConvertActivation}, + {"FullyConnected", ConvertFullyConnected}, + {"SoftmaxOutput", ConvertSoftmaxOutput}, + {"Flatten", ConvertFlatten}, + {"BatchNorm", ConvertBatchNorm}, + {"elemwise_add", ConvertElementwiseAdd}}; + +} // namespace nnvm_to_onnx +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_OPERATOR_CONTRIB_NNVM_TO_ONNX_INL_H_ diff --git a/src/operator/contrib/nnvm_to_onnx.cc b/src/operator/contrib/nnvm_to_onnx.cc new file mode 100644 index 000000000000..902466614c7c --- /dev/null +++ b/src/operator/contrib/nnvm_to_onnx.cc @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file trt.cc + * \brief TensorRT operation registration + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include "./nnvm_to_onnx-inl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../common/serialization.h" +#include "../../common/utils.h" +#include "../../ndarray/ndarray_function.h" +#include "../../operator/nn/activation-inl.h" +#include "../../operator/nn/batch_norm-inl.h" +#include "../../operator/nn/convolution-inl.h" +#include "../../operator/nn/fully_connected-inl.h" +#include "../../operator/nn/pooling-inl.h" +#include "../../operator/softmax_output-inl.h" +#include "./tensorrt-inl.h" + +#if MXNET_USE_TENSORRT_ONNX_CHECKER +#include +#endif // MXNET_USE_TENSORRT_ONNX_CHECKER + +namespace mxnet { +namespace op { +namespace nnvm_to_onnx { + +op::TRTParam ConvertNnvmGraphToOnnx( + const nnvm::Graph& g, + std::unordered_map* const shared_buffer) { + op::TRTParam trt_param; + op::tensorrt::NameToIdx_t trt_input_map; + op::tensorrt::InferenceMap_t trt_output_map; + + const nnvm::IndexedGraph& ig = g.indexed_graph(); + const auto& storage_types = g.GetAttr("storage_type"); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shape_inputs = g.GetAttr("shape_inputs"); + + for (auto& e : storage_types) { + if (e != mshadow::kFloat32) { + LOG(FATAL) << "ONNX converter does not support types other than float32 " + "right now."; + } + } + + ModelProto model_proto; + // Need to determine IR versions and features to support + model_proto.set_ir_version(static_cast(2)); + GraphProto* graph_proto = model_proto.mutable_graph(); + + std::unordered_map placeholder_shapes = + GetPlaceholderShapes(shape_inputs, ig); + std::unordered_map output_lookup = GetOutputLookup(ig); + uint32_t current_input = 0; + + // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc. + for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) { + const IndexedGraph::Node& node = ig[node_idx]; + const nnvm::Node* source = node.source; + const NodeAttrs& attrs = source->attrs; + const Op* op = source->op(); + + std::string node_name = attrs.name; + // Here, "variable" actually means anything that's not an op i.e. a constant (weights) or a + // placeholder + if (source->is_variable()) { + // Is this a placeholder? + if (shared_buffer->count(node_name) == 0) { + // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky. + // Need to figure out how to properly fix it. + if (node_name.find("label") != std::string::npos) { + current_input++; + continue; + } + trt_input_map.emplace(node_name, current_input++); + ConvertPlaceholder(node_name, placeholder_shapes, graph_proto); + } else { + // If it's not a placeholder, then by exclusion it's a constant. + ConvertConstant(graph_proto, node_name, shared_buffer); + } // is_placeholder + } else { + // It's an op, rather than a "variable" (constant or placeholder) + NodeProto* node_proto = graph_proto->add_node(); + node_proto->set_name(node_name); + if (converter_map.count(op->name) == 0) { + LOG(FATAL) << "Conversion for node of type " << op->name << " (node " + << node_name << ") " + << " is not supported yet."; + } + // Find function ptr to a converter based on the op name, and invoke the converter. This + // looks unsafe because find may not succeed, but it does because we're in the operator + // logic after testing that this node name does not represent a variable. + converter_map.find(op->name)->second(node_proto, attrs, ig, node.inputs); + // Add all inputs to the current node (i.e. add graph edges) + for (const nnvm::IndexedGraph::NodeEntry& entry : node.inputs) { + std::string in_node_name = ig[entry.node_id].source->attrs.name; + // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less + // hacky way to do it than name matching. + if (in_node_name.find("label") != std::string::npos) { + continue; + } + node_proto->add_input(in_node_name); + } + // The node's output will have the same name as the node name. + node_proto->add_output(node_name); + // See if the current node is an output node + auto out_iter = output_lookup.find(node_name); + // We found an output + if (out_iter != output_lookup.end()) { + ConvertOutput(&trt_output_map, graph_proto, out_iter, node_name, g, + storage_types, dtypes); + } // output found + } // conversion function exists + } // loop over i from 0 to num_nodes + + model_proto.SerializeToString(&trt_param.serialized_onnx_graph); + common::Serialize(trt_input_map, + &trt_param.serialized_input_map); + common::Serialize(trt_output_map, + &trt_param.serialized_output_map); + +#if MXNET_USE_TENSORRT_ONNX_CHECKER + onnx::checker::check_model(model_proto); +#endif // MXNET_USE_TENSORRT_ONNX_CHECKER + + return trt_param; +} + +void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& conv_param = nnvm::get(attrs.parsed); + + node_proto->set_op_type("Conv"); + + const TShape kernel = conv_param.kernel; + const TShape stride = conv_param.stride; + const TShape dilate = conv_param.dilate; + const TShape pad = conv_param.pad; + const uint32_t num_group = conv_param.num_group; + // const bool no_bias = conv_param.no_bias; + const dmlc::optional layout = conv_param.layout; + + // kernel shape + AttributeProto* const kernel_shape = node_proto->add_attribute(); + kernel_shape->set_name("kernel_shape"); + kernel_shape->set_type(AttributeProto::INTS); + + for (const dim_t kval : kernel) { + kernel_shape->add_ints(static_cast(kval)); + } + + // pads + AttributeProto* const pads = node_proto->add_attribute(); + pads->set_name("pads"); + pads->set_type(AttributeProto::INTS); + + for (const dim_t kval : pad) { + pads->add_ints(static_cast(kval)); + pads->add_ints(static_cast(kval)); + } + + // dilations + AttributeProto* const dilations = node_proto->add_attribute(); + dilations->set_name("dilations"); + dilations->set_type(AttributeProto::INTS); + for (const dim_t kval : dilate) { + dilations->add_ints(static_cast(kval)); + } + + // strides + AttributeProto* const strides = node_proto->add_attribute(); + strides->set_name("strides"); + strides->set_type(AttributeProto::INTS); + for (const dim_t kval : stride) { + strides->add_ints(static_cast(kval)); + } + + // group + AttributeProto* const group = node_proto->add_attribute(); + group->set_name("group"); + group->set_type(AttributeProto::INT); + group->set_i(static_cast(num_group)); +} // end ConvertConvolution + +void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& pooling_param = nnvm::get(attrs.parsed); + + const TShape kernel = pooling_param.kernel; + const TShape stride = pooling_param.stride; + const TShape pad = pooling_param.pad; + const int pool_type = pooling_param.pool_type; + const bool global_pool = pooling_param.global_pool; + + if (global_pool) { + if (pool_type == 0) { + node_proto->set_op_type("GlobalMaxPool"); + } else { + node_proto->set_op_type("GlobalAveragePool"); + } + return; + } + + // kernel_shape + AttributeProto* const kernel_shape = node_proto->add_attribute(); + kernel_shape->set_name("kernel_shape"); + kernel_shape->set_type(AttributeProto::INTS); + for (int kval : kernel) { + kernel_shape->add_ints(static_cast(kval)); + } + + // pads + AttributeProto* const pads = node_proto->add_attribute(); + pads->set_name("pads"); + pads->set_type(AttributeProto::INTS); + for (int kval : pad) { + pads->add_ints(static_cast(kval)); + } + + // strides + AttributeProto* const strides = node_proto->add_attribute(); + strides->set_name("strides"); + strides->set_type(AttributeProto::INTS); + for (int kval : stride) { + strides->add_ints(static_cast(kval)); + } + + if (pool_type == 0) { + node_proto->set_op_type("MaxPool"); + } else { + node_proto->set_op_type("AveragePool"); + } // average pooling + // not global pooling +} // end ConvertPooling + +void ConvertActivation(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& act_param = nnvm::get(attrs.parsed); + std::string act_type; + switch (act_param.act_type) { + case op::activation::kReLU: + act_type = "Relu"; + break; + case op::activation::kSigmoid: + act_type = "Sigmoid"; + break; + case op::activation::kTanh: + act_type = "Tanh"; + break; + case op::activation::kSoftReLU: + // act_type = "SoftReLU"; + throw dmlc::Error("SoftReLU is not supported in ONNX"); + break; + default: + throw dmlc::Error("Activation of such type doesn't exist"); + } + + node_proto->set_op_type(act_type); +} + +void ConvertFullyConnected(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + const auto& act_param = nnvm::get(attrs.parsed); + if (act_param.no_bias) { + node_proto->set_op_type("MatMul"); + } else { + node_proto->set_op_type("Gemm"); + + AttributeProto* const alpha = node_proto->add_attribute(); + alpha->set_name("alpha"); + alpha->set_type(AttributeProto::FLOAT); + alpha->set_f(1.0f); + + AttributeProto* const beta = node_proto->add_attribute(); + beta->set_name("beta"); + beta->set_type(AttributeProto::FLOAT); + beta->set_f(1.0f); + + AttributeProto* const broadcast = node_proto->add_attribute(); + broadcast->set_name("broadcast"); + broadcast->set_type(AttributeProto::INT); + broadcast->set_i(1); + + AttributeProto* const transA = node_proto->add_attribute(); + transA->set_name("transA"); + transA->set_type(AttributeProto::INT); + transA->set_i(0); + + AttributeProto* const transB = node_proto->add_attribute(); + transB->set_name("transB"); + transB->set_type(AttributeProto::INT); + transB->set_i(1); + } +} + +void ConvertSoftmaxOutput(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Softmax"); + + // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its + // node params. This attribute is only relevant when the input is coerced to 2D, and in that + // case dimension 0 is assumed to be the batch dimension. + AttributeProto* const axis = node_proto->add_attribute(); + axis->set_name("axis"); + axis->set_type(AttributeProto::INT); + axis->set_i(1); +} + +void ConvertFlatten(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Flatten"); + + // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its + // node params. This attribute is only relevant when the input is coerced to 2D, and in that + // case dimension 0 is assumed to be the batch dimension. + AttributeProto* const axis = node_proto->add_attribute(); + axis->set_name("axis"); + axis->set_type(AttributeProto::INT); + axis->set_i(1); +} + +void ConvertBatchNorm(NodeProto* node_proto, const NodeAttrs& attrs, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("BatchNormalization"); + const auto& param = nnvm::get(attrs.parsed); + + AttributeProto* const epsilon = node_proto->add_attribute(); + epsilon->set_name("epsilon"); + epsilon->set_type(AttributeProto::FLOAT); + epsilon->set_f(static_cast(param.eps)); + + AttributeProto* const is_test = node_proto->add_attribute(); + is_test->set_name("is_test"); + is_test->set_type(AttributeProto::INT); + is_test->set_i(1); + + AttributeProto* const momentum = node_proto->add_attribute(); + momentum->set_name("momentum"); + momentum->set_type(AttributeProto::FLOAT); + momentum->set_f(param.momentum); + + AttributeProto* const spatial = node_proto->add_attribute(); + spatial->set_name("spatial"); + spatial->set_type(AttributeProto::INT); + spatial->set_i(1); + + AttributeProto* const consumed = node_proto->add_attribute(); + consumed->set_name("consumed_inputs"); + consumed->set_type(AttributeProto::INTS); + + for (int i = 0; i < 5; i++) { + int val = (i < 3) ? 0 : 1; + consumed->add_ints(static_cast(val)); + } +} + +void ConvertElementwiseAdd(NodeProto* node_proto, const NodeAttrs& /*attrs*/, + const nnvm::IndexedGraph& /*ig*/, + const array_view& /*inputs*/) { + node_proto->set_op_type("Add"); + AttributeProto* const axis = node_proto->add_attribute(); + axis->set_name("axis"); + axis->set_type(AttributeProto::INT); + axis->set_i(1); + + AttributeProto* const broadcast = node_proto->add_attribute(); + broadcast->set_name("broadcast"); + broadcast->set_type(AttributeProto::INT); + broadcast->set_i(0); // 1 +} + +std::unordered_map GetPlaceholderShapes( + const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) { + std::unordered_map placeholder_shapes; + for (uint32_t i = 0; i < shape_inputs.size(); ++i) { + std::string name = ig[ig.input_nodes()[i]].source->attrs.name; + TShape shp = shape_inputs[i]; + if (shp.ndim() > 0) { + placeholder_shapes.emplace(name, shp); + } + } + return placeholder_shapes; +} + +std::unordered_map GetOutputLookup( + const nnvm::IndexedGraph& ig) { + std::unordered_map output_lookup; + const std::vector& graph_outputs = + ig.outputs(); + for (uint32_t i = 0; i < graph_outputs.size(); ++i) { + const uint32_t id = graph_outputs[i].node_id; + const IndexedGraph::Node ig_node = ig[id]; + const nnvm::Node* const source = ig_node.source; + const std::string name = source->attrs.name; + output_lookup.emplace(name, i); + } + return output_lookup; +} + +void ConvertPlaceholder( + const std::string& node_name, + const std::unordered_map& placeholder_shapes, + GraphProto* const graph_proto) { + auto val_info_proto = graph_proto->add_input(); + auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type(); + auto shape_proto = type_proto->mutable_shape(); + + val_info_proto->set_name(node_name); + // Will support fp16, etc. in the near future + type_proto->set_elem_type(TensorProto_DataType_FLOAT); + auto entry_shape = placeholder_shapes.find(node_name)->second; + + for (const auto& elem : entry_shape) { + TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim(); + tsp_dim->set_dim_value(static_cast(elem)); + } +} + +void ConvertConstant( + GraphProto* const graph_proto, const std::string& node_name, + std::unordered_map* const shared_buffer) { + NodeProto* const node_proto = graph_proto->add_node(); + node_proto->set_name(node_name); + node_proto->add_output(node_name); + node_proto->set_op_type("Constant"); + + const NDArray nd = shared_buffer->find(node_name)->second; + const TBlob& blob = nd.data(); + const TShape shape = blob.shape_; + const int32_t size = shape.Size(); + + std::shared_ptr shared_data_ptr(new float[size]); + float* const data_ptr = shared_data_ptr.get(); + nd.SyncCopyToCPU(static_cast(data_ptr), size); + + AttributeProto* const tensor_attr = node_proto->add_attribute(); + tensor_attr->set_name("value"); + tensor_attr->set_type(AttributeProto::TENSOR); + + TensorProto* const tensor_proto = tensor_attr->mutable_t(); + tensor_proto->set_data_type(TensorProto_DataType_FLOAT); + for (auto& dim : shape) { + tensor_proto->add_dims(static_cast(dim)); + } + + for (int blob_idx = 0; blob_idx < size; ++blob_idx) { + tensor_proto->add_float_data(data_ptr[blob_idx]); + } +} + +void ConvertOutput( + op::tensorrt::InferenceMap_t* const trt_output_map, + GraphProto* const graph_proto, + const std::unordered_map::iterator& out_iter, + const std::string& node_name, const nnvm::Graph& g, + const StorageTypeVector& storage_types, const DTypeVector& dtypes) { + const nnvm::IndexedGraph& ig = g.indexed_graph(); + uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]); + TShape out_shape = g.GetAttr("shape")[out_idx]; + int storage_type = storage_types[out_idx]; + int dtype = dtypes[out_idx]; + + // This should work with fp16 as well + op::tensorrt::InferenceTuple_t out_tuple{out_iter->second, out_shape, storage_type, + dtype}; + + trt_output_map->emplace(node_name, out_tuple); + + auto graph_out = graph_proto->add_output(); + auto tensor_type = graph_out->mutable_type()->mutable_tensor_type(); + auto tensor_shape_proto = tensor_type->mutable_shape(); + graph_out->set_name(node_name); + + // Also support fp16. + tensor_type->set_elem_type(TensorProto_DataType_FLOAT); + + for (int64_t dim_shp : out_shape) { + TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim(); + tsp_dim->set_dim_value(static_cast(dim_shp)); + } +} + +} // namespace nnvm_to_onnx +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/operator/contrib/tensorrt-inl.h b/src/operator/contrib/tensorrt-inl.h new file mode 100644 index 000000000000..be335ab1208f --- /dev/null +++ b/src/operator/contrib/tensorrt-inl.h @@ -0,0 +1,113 @@ +#ifndef MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_ +#define MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file tensorrt-inl.h + * \brief TensorRT Operator + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../operator_common.h" +#include "../../common/utils.h" +#include "../../common/serialization.h" +#include "../../executor/exec_pass.h" +#include "../../executor/graph_executor.h" +#include "../../executor/onnx_to_tensorrt.h" + +namespace mxnet { +namespace op { + +using namespace nnvm; +using namespace ::onnx; +using int64 = ::google::protobuf::int64; + +namespace tensorrt { + enum class TypeIO { Inputs = 0, Outputs = 1 }; + using NameToIdx_t = std::map; + using InferenceTuple_t = std::tuple; + using InferenceMap_t = std::map; +} // namespace tensorrt + +using trt_name_to_idx = std::map; + +struct TRTParam : public dmlc::Parameter { + std::string serialized_onnx_graph; + std::string serialized_input_map; + std::string serialized_output_map; + tensorrt::NameToIdx_t input_map; + tensorrt::InferenceMap_t output_map; + ::onnx::ModelProto onnx_pb_graph; + + TRTParam() {} + + TRTParam(const ::onnx::ModelProto& onnx_graph, + const tensorrt::InferenceMap_t& input_map, + const tensorrt::NameToIdx_t& output_map) { + common::Serialize(input_map, &serialized_input_map); + common::Serialize(output_map, &serialized_output_map); + onnx_graph.SerializeToString(&serialized_onnx_graph); + } + +DMLC_DECLARE_PARAMETER(TRTParam) { + DMLC_DECLARE_FIELD(serialized_onnx_graph) + .describe("Serialized ONNX graph"); + DMLC_DECLARE_FIELD(serialized_input_map) + .describe("Map from inputs to topological order as input."); + DMLC_DECLARE_FIELD(serialized_output_map) + .describe("Map from outputs to order in g.outputs."); + } +}; + +struct TRTEngineParam { + nvinfer1::IExecutionContext* trt_executor; + std::vector > binding_map; +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_OPERATOR_CONTRIB_TENSORRT_INL_H_ diff --git a/src/operator/contrib/tensorrt.cc b/src/operator/contrib/tensorrt.cc new file mode 100644 index 000000000000..619fe1e2b8f4 --- /dev/null +++ b/src/operator/contrib/tensorrt.cc @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file trt.cc + * \brief TensorRT operation registration + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include "./tensorrt-inl.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../common/serialization.h" +#include "../../common/utils.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(TRTParam); + +OpStatePtr GetPtrMapping(nvinfer1::ICudaEngine* trt_engine, + tensorrt::NameToIdx_t input_map, + tensorrt::NameToIdx_t output_map) { + TRTEngineParam param; + for (int b = 0; b < trt_engine->getNbBindings(); ++b) { + const std::string& binding_name = trt_engine->getBindingName(b); + if (trt_engine->bindingIsInput(b)) { + param.binding_map.emplace_back(input_map[binding_name], + tensorrt::TypeIO::Inputs); + } else { + param.binding_map.emplace_back(output_map[binding_name], + tensorrt::TypeIO::Outputs); + } + } + param.trt_executor = trt_engine->createExecutionContext(); + return OpStatePtr::Create(param); +} + +OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context /*ctx*/, + const std::vector& /*ishape*/, + const std::vector& /*itype*/) { + const auto& node_param = nnvm::get(attrs.parsed); + + ::onnx::ModelProto model_proto; + bool success = model_proto.ParseFromString(node_param.serialized_onnx_graph); + if (!success) { + LOG(FATAL) << "Problems parsing serialized ONNX model."; + } + auto graph = model_proto.graph(); + auto first_input_type = graph.input(0).type().tensor_type(); + auto dim_value = first_input_type.shape().dim(0).dim_value(); + auto batch_size = static_cast(dim_value); + // Need to set up max workspace size based on device properties + nvinfer1::ICudaEngine* const trt_engine = ::onnx_to_tensorrt::onnxToTrtCtx( + node_param.serialized_onnx_graph, batch_size, 1 << 30); + + tensorrt::NameToIdx_t output_map; + for (auto& el : node_param.output_map) { + output_map[el.first] = std::get<0>(el.second); + } + return GetPtrMapping(trt_engine, node_param.input_map, output_map); +} + +void TRTParamParser(nnvm::NodeAttrs* attrs) { + TRTParam param_; + + try { + param_.Init(attrs->dict); + common::Deserialize(¶m_.input_map, param_.serialized_input_map); + common::Deserialize(¶m_.output_map, param_.serialized_output_map); + param_.onnx_pb_graph.ParseFromString(param_.serialized_onnx_graph); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + + attrs->parsed = std::move(param_); +} + +inline bool TRTInferShape(const NodeAttrs& attrs, std::vector* /*in_shape*/, + std::vector* out_shape) { + const auto &node_param = nnvm::get(attrs.parsed); + for (auto& el : node_param.output_map) { + (*out_shape)[std::get<0>(el.second)] = std::get<1>(el.second); + } + return true; +} + +inline bool TRTInferStorageType(const NodeAttrs& /*attrs*/, const int /*dev_mask*/, + DispatchMode* dispatch_mode, + std::vector* /*in_storage_type*/, + std::vector* out_storage_type) { + return storage_type_assign(out_storage_type, mxnet::kDefaultStorage, + dispatch_mode, DispatchMode::kFCompute); +} + +inline bool TRTInferType(const NodeAttrs& attrs, std::vector* /*in_dtype*/, + std::vector* out_dtype) { + const auto& node_param = nnvm::get(attrs.parsed); + for (auto& el : node_param.output_map) { + (*out_dtype)[std::get<0>(el.second)] = std::get<3>(el.second); + } + return true; +} + +inline std::vector TRTListInputNames(const NodeAttrs& attrs) { + std::vector output; + const auto& node_param = nnvm::get(attrs.parsed); + output.resize(node_param.input_map.size()); + for (auto& el : node_param.input_map) { + output[el.second] = el.first; + } + return output; +} + +inline std::vector TRTListOutputNames(const NodeAttrs& attrs) { + std::vector output; + const auto& node_param = nnvm::get(attrs.parsed); + output.resize(node_param.output_map.size()); + for (auto& el : node_param.output_map) { + output[std::get<0>(el.second)] = el.first; + } + return output; +} + +NNVM_REGISTER_OP(_trt_op) + .describe(R"code(TRT operation (one engine) +)code" ADD_FILELINE) + .set_num_inputs([](const NodeAttrs& attrs) { + const auto& node_param = nnvm::get(attrs.parsed); + return node_param.input_map.size(); + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const auto& node_param = nnvm::get(attrs.parsed); + return node_param.output_map.size(); + }) + .set_attr_parser(TRTParamParser) + .set_attr("FInferShape", TRTInferShape) + .set_attr("FInferType", TRTInferType) + .set_attr("FListInputNames", TRTListInputNames) + .set_attr("FListOutputNames", TRTListOutputNames) + .set_attr("FCreateOpState", TRTCreateState) + .set_attr("FInferStorageType", TRTInferStorageType); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/operator/contrib/tensorrt.cu b/src/operator/contrib/tensorrt.cu new file mode 100644 index 000000000000..2fe8727b73e4 --- /dev/null +++ b/src/operator/contrib/tensorrt.cu @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file trt.cu + * \brief TensorRT GPU operation + * \author Marek Kolodziej, Clement Fuji Tsang +*/ + +#if MXNET_USE_TENSORRT + +#include "./tensorrt-inl.h" + +namespace mxnet { +namespace op { + +#define CHECK_CUDART(x) do { \ + cudaError_t res = (x); \ + if (res != cudaSuccess) { \ + fprintf(stderr, "CUDART: %s = %d (%s) at (%s:%d)\n", \ + #x, res, cudaGetErrorString(res), __FILE__, __LINE__); \ + exit(1); \ + } \ +} while (0) + +void TRTCompute(const OpStatePtr& state, const OpContext& ctx, + const std::vector& inputs, const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + + Stream* s = ctx.get_stream(); + cudaStream_t cuda_s = Stream::GetStream(s); + const auto& param = state.get_state(); + std::vector bindings; + bindings.reserve(param.binding_map.size()); + for (auto& p : param.binding_map) { + if (p.second == tensorrt::TypeIO::Inputs) { + bindings.emplace_back(inputs[p.first].dptr_); + } else { + bindings.emplace_back(outputs[p.first].dptr_); + } + } + + const int batch_size = static_cast(inputs[0].shape_[0]); + param.trt_executor->enqueue(batch_size, bindings.data(), cuda_s, nullptr); + CHECK_CUDART(cudaStreamSynchronize(cuda_s)); +} + +NNVM_REGISTER_OP(_trt_op) +.set_attr("FStatefulCompute", TRTCompute); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/tests/cpp/misc/serialization.cc b/tests/cpp/misc/serialization.cc new file mode 100644 index 000000000000..96f8b6c3a3a7 --- /dev/null +++ b/tests/cpp/misc/serialization.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include <../../../src/common/serialization.h> + +using namespace mxnet; +using namespace std; + +/* + * Test that used datastruct are properly serialized and deserialized + */ + +TEST(SerializerTest, InputMapCorrect) { + std::map input_map; + input_map.emplace("input_0", 2); + input_map.emplace("another_input", 0); + input_map.emplace("last_input", 1); + std::string serialized_data; + common::Serialize(input_map, &serialized_data); + std::map deserialized_input_map; + common::Deserialize(&deserialized_input_map, serialized_data); + ASSERT_EQ(input_map.size(), deserialized_input_map.size()); + for (auto& p : input_map) { + auto it = deserialized_input_map.find(p.first); + ASSERT_NE(it, deserialized_input_map.end()); + ASSERT_EQ(it->second, p.second); + } +} + +TEST(SerializerTest, OutputMapCorrect) { + std::map > output_map; + output_map.emplace("output_0", std::make_tuple(1, TShape({23, 12, 63, 432}), 0, 1)); + output_map.emplace("another_output", std::make_tuple(2, TShape({23, 123}), 14, -23)); + output_map.emplace("last_output", std::make_tuple(0, TShape({0}), -1, 0)); + std::string serialized_data; + common::Serialize(output_map, &serialized_data); + std::map > deserialized_output_map; + common::Deserialize(&deserialized_output_map, serialized_data); + ASSERT_EQ(output_map.size(), deserialized_output_map.size()); + for (auto& p : output_map) { + auto it = deserialized_output_map.find(p.first); + ASSERT_NE(it, deserialized_output_map.end()); + auto lhs = it->second; + auto rhs = p.second; + ASSERT_EQ(std::get<0>(lhs), std::get<0>(rhs)); + ASSERT_EQ(std::get<1>(lhs), std::get<1>(rhs)); + ASSERT_EQ(std::get<2>(lhs), std::get<2>(rhs)); + ASSERT_EQ(std::get<3>(lhs), std::get<3>(rhs)); + } +} + diff --git a/tests/python/tensorrt/common.py b/tests/python/tensorrt/common.py new file mode 100644 index 000000000000..c64649a19c41 --- /dev/null +++ b/tests/python/tensorrt/common.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +from ctypes.util import find_library + + +def check_tensorrt_installation(): + assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library" + + +def get_use_tensorrt(): + return int(os.environ.get("MXNET_USE_TENSORRT", 0)) + + +def set_use_tensorrt(status=False): + os.environ["MXNET_USE_TENSORRT"] = str(int(status)) + + +def merge_dicts(*dict_args): + """Merge arg_params and aux_params to populate shared_buffer""" + result = {} + for dictionary in dict_args: + result.update(dictionary) + return result + + +def get_fp16_infer_for_fp16_graph(): + return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0)) + + +def set_fp16_infer_for_fp16_graph(status=False): + os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status)) diff --git a/tests/python/tensorrt/lenet5_common.py b/tests/python/tensorrt/lenet5_common.py new file mode 100644 index 000000000000..347d6f3c11ba --- /dev/null +++ b/tests/python/tensorrt/lenet5_common.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import mxnet as mx +from common import * + +def get_iters(mnist, batch_size): + """Get MNIST iterators.""" + train_iter = mx.io.NDArrayIter(mnist['train_data'], + mnist['train_label'], + batch_size, + shuffle=True) + val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + all_test_labels = np.array(mnist['test_label']) + return train_iter, val_iter, test_iter, all_test_labels diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py new file mode 100644 index 000000000000..74de66620e88 --- /dev/null +++ b/tests/python/tensorrt/lenet5_train.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import numpy as np +import mxnet as mx +from lenet5_common import get_iters + +def lenet5(): + """LeNet-5 Symbol""" + #pylint: disable=no-member + data = mx.sym.Variable('data') + conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=20) + tanh1 = mx.sym.Activation(data=conv1, act_type="tanh") + pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # second conv + conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=50) + tanh2 = mx.sym.Activation(data=conv2, act_type="tanh") + pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", + kernel=(2, 2), stride=(2, 2)) + # first fullc + flatten = mx.sym.Flatten(data=pool2) + fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500) + tanh3 = mx.sym.Activation(data=fc1, act_type="tanh") + # second fullc + fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10) + # loss + lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax') + #pylint: enable=no-member + return lenet + +def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter): + """train LeNet-5 model on MNIST data""" + ctx = mx.gpu(0) + lenet_model = mx.mod.Module(lenet5(), context=ctx) + + lenet_model.fit(train_iter, + eval_data=val_iter, + optimizer='sgd', + optimizer_params={'learning_rate': 0.1, 'momentum': 0.9}, + eval_metric='acc', + batch_end_callback=mx.callback.Speedometer(batch_size, 1), + num_epoch=num_epochs) + + # predict accuracy for lenet + acc = mx.metric.Accuracy() + lenet_model.score(test_iter, acc) + accuracy = acc.get()[1] + assert accuracy > 0.95, "LeNet-5 training accuracy on MNIST was too low" + return lenet_model + + +def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size): + """Run inference with either MXNet or TensorRT""" + + shared_buffer = merge_dicts(arg_params, aux_params) + if not get_use_tensorrt(): + shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) + executor = sym.simple_bind(ctx=mx.gpu(0), + data=(batch_size,) + mnist['test_data'].shape[1:], + softmax_label=(batch_size,), + shared_buffer=shared_buffer, + grad_req='null', + force_rebind=True) + + # Get this value from all_test_labels + # Also get classes from the dataset + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + + example_ct = 0 + + for idx, dbatch in enumerate(test_iter): + executor.arg_dict["data"][:] = dbatch.data[0] + executor.forward(is_train=False) + offset = idx*batch_size + extent = batch_size if num_ex - offset > batch_size else num_ex - offset + all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent] + example_ct += extent + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum() + + percentage = 100.0 * matches / example_ct + + return percentage + +if __name__ == '__main__': + + num_epochs = 10 + batch_size = 128 + model_name = 'lenet5' + model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") + model_file = '%s/%s-symbol.json' % (model_dir, model_name) + params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) + + if not (os.path.exists(model_file) and os.path.exists(params_file)): + mnist = mx.test_utils.get_mnist() + + _, _, _, all_test_labels = get_iters(mnist, batch_size) + + trained_lenet = train_lenet5(num_epochs, batch_size, + *get_iters(mnist, batch_size)[:-1]) + trained_lenet.save_checkpoint(model_name, num_epochs) diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py new file mode 100644 index 000000000000..37c3f5da2689 --- /dev/null +++ b/tests/python/tensorrt/test_cycle.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from common import * + +def detect_cycle_from(sym, visited, stack): + visited.add(sym.handle.value) + stack.add(sym.handle.value) + for s in s.get_children(): + if s.handle.value not in visited: + if detect_cycle_from(sym, visited, stack): + return True + elif s.handle.value in stack: + return True + stack.remove(sym.handle.value) + return False + +def has_no_cycle(sym): + visited = set() + stack = set() + all_nodes = sym.get_internals() + for s in all_nodes: + if s.handle.value in visited: + if detect_cycle_from(s, visited, stack): + return False + return True + +def test_simple_cycle(): + inp = mx.sym.Variable('input', shape=[1,10]) + A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A') + B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B') + D = mx.sym.sin(data=A, name='D') + C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C') + arg_params = { + 'I_weight': mx.nd.zeros([10,10]), + 'I_bias': mx.nd.zeros([10]), + 'A_weight': mx.nd.zeros([10,10]), + 'A_bias': mx.nd.zeros([10]), + 'B_weight': mx.nd.zeros([10,10]), + 'B_bias': mx.nd.zeros([10]), + } + set_use_tensorrt(True) + executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), + shared_buffer=arg_params, grad_req='null', force_rebind=True) + assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle" + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py new file mode 100644 index 000000000000..396b7dad4ab8 --- /dev/null +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import numpy as np +import mxnet as mx +from common import * +from lenet5_common import get_iters + +def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size): + """Run inference with either MXNet or TensorRT""" + + shared_buffer = merge_dicts(arg_params, aux_params) + if not get_use_tensorrt(): + shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) + executor = sym.simple_bind(ctx=mx.gpu(0), + data=(batch_size,) + mnist['test_data'].shape[1:], + softmax_label=(batch_size,), + shared_buffer=shared_buffer, + grad_req='null', + force_rebind=True) + + # Get this value from all_test_labels + # Also get classes from the dataset + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) + + example_ct = 0 + + for idx, dbatch in enumerate(test_iter): + executor.arg_dict["data"][:] = dbatch.data[0] + executor.forward(is_train=False) + offset = idx*batch_size + extent = batch_size if num_ex - offset > batch_size else num_ex - offset + all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent] + example_ct += extent + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum() + + percentage = 100.0 * matches / example_ct + + return percentage + +def test_tensorrt_inference(): + """Run LeNet-5 inference comparison between MXNet and TensorRT.""" + check_tensorrt_installation() + mnist = mx.test_utils.get_mnist() + num_epochs = 10 + batch_size = 128 + model_name = 'lenet5' + model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") + model_file = '%s/%s-symbol.json' % (model_dir, model_name) + params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) + + _, _, _, all_test_labels = get_iters(mnist, batch_size) + + # Load serialized MXNet model (model-symbol.json + model-epoch.params) + sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) + + print("LeNet-5 test") + print("Running inference in MXNet") + set_use_tensorrt(False) + mx_pct = run_inference(sym, arg_params, aux_params, mnist, + all_test_labels, batch_size=batch_size) + + print("Running inference in MXNet-TensorRT") + set_use_tensorrt(True) + trt_pct = run_inference(sym, arg_params, aux_params, mnist, + all_test_labels, batch_size=batch_size) + + print("MXNet accuracy: %f" % mx_pct) + print("MXNet-TensorRT accuracy: %f" % trt_pct) + + assert abs(mx_pct - trt_pct) < 1e-2, \ + """Diff. between MXNet & TensorRT accuracy too high: + MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py b/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py new file mode 100644 index 000000000000..eda2db0cc8f5 --- /dev/null +++ b/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc +import gluoncv +import mxnet as mx +import multiprocessing +import numpy as np +import os +import sys + +from mxnet.gluon.data.vision import transforms +from mxnet import gluon +from time import time + +def get_use_tensorrt(): + return int(os.environ.get("MXNET_USE_TENSORRT", 0)) + +def set_use_tensorrt(status=False): + os.environ["MXNET_USE_TENSORRT"] = str(int(status)) + +def get_fp16_infer_for_fp16_graph(): + return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0)) + +def set_fp16_infer_for_fp16_graph(status=False): + os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status)) + +#ssd_512_resnet50_v1_coco +def get_ssd_model(model_name='ssd_512_mobilenet1_0_coco', use_tensorrt=True, + ctx=mx.gpu(0), batch_size=32, fp16_for_fp32_graph=False): + + set_use_tensorrt(use_tensorrt) + set_fp16_infer_for_fp16_graph(fp16_for_fp32_graph) + net = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + anchors, class_preds, box_preds = net(data) + all_preds = mx.sym.concat(anchors, class_preds, box_preds, dim=2) + all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) + + if not get_use_tensorrt(): + all_params = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in all_params.items()]) + + # class_preds + executor = all_preds.simple_bind(ctx=ctx, data=(batch_size, 3, 224, 224), grad_req='null', + shared_buffer=all_params, force_rebind=True) + return executor + + +def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, + ctx=mx.gpu(0), batch_size=128, fp16_for_fp32_graph=False, imagenet=False): + + set_use_tensorrt(use_tensorrt) + set_fp16_infer_for_fp16_graph(fp16_for_fp32_graph) + net = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + out = net(data) + + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + + all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) + + if not get_use_tensorrt(): + all_params = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in all_params.items()]) + + if imagenet: + h, w = 224, 224 + else: + h, w = 32, 32 + + executor = softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), grad_req='null', + shared_buffer=all_params, force_rebind=True) + return executor + +def cifar10_infer(data_dir='./data', model_name='cifar_resnet56_v1', use_tensorrt=True, + ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): + + executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph, imagenet=False) + + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + + all_label_test = np.zeros(num_ex) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + ]) + + data_loader = lambda: gluon.data.DataLoader( + gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), + batch_size=batch_size, shuffle=False, num_workers=num_workers) + + val_data = data_loader() + + for idx, (data, label) in enumerate(val_data): + extent = data.shape[0] + offset = idx*batch_size + all_label_test[offset:offset+extent] = label.asnumpy() + + # warm-up, but don't use result + executor.arg_dict["data"][:extent, :] = data + executor.forward(is_train=False) + executor.outputs[0].wait_to_read() + + gc.collect() + + val_data = data_loader() + example_ct = 0 + + start = time() + + for idx, (data, label) in enumerate(val_data): + extent = data.shape[0] + executor.arg_dict["data"][:extent, :] = data + executor.forward(is_train=False) + preds = executor.outputs[0].asnumpy() + offset = idx*batch_size + all_preds[offset:offset+extent, :] = preds[:extent] + example_ct += extent + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum() + duration = time() - start + + return duration, 100.0 * matches / example_ct + +def ssd_infer(model_name='ssd_512_mobilenet1_0_voc', use_tensorrt=True, + ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): + + print("Running SSD inference with model: %s" % model_name) + executor = get_ssd_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph) + + start = None + num_runs = 50 + + for i in range(2): + data = np.random.randn(batch_size, 3, 224, 224) + executor.arg_dict["data"] = data + if i == 1: + start = time() + for runs in range(num_runs): + executor.forward(is_train = False) + executor.outputs[0].wait_to_read() +# all_preds = executor.outputs[0].asnumpy() +# anchors = all_preds[:, :, 0] +# class_preds = all_preds[:, :, 1] +# box_preds = all_preds[:, :, 2:] + + return time() - start + +def classif_imagenet_infer(model_name='ssd_512_mobilenet1_0_coco', use_tensorrt=True, + ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): + + executor = get_ssd_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph) + executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph, imagenet=False) + + start = None + num_runs = 2 + + for i in range(2): + data = np.random.randn(batch_size, 3, 224, 224) + executor.arg_dict["data"] = data + if i == 1: + start = time() + for runs in range(num_runs): + executor.forward(is_train = False) + executor.outputs[0].wait_to_read() + + return time() - start + + +def run_experiment_for(model_name, batch_size, num_workers, fp16_for_fp32_graph): + print("\n===========================================") + print("Model: %s" % model_name) + print("===========================================") + print("*** Running inference using pure MxNet ***\n") + mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, fp16_for_fp32_graph=fp16_for_fp32_graph, use_tensorrt=False) + print("\nMxNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct)) + + print("\n*** Running inference using MxNet + TensorRT ***\n") + trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, use_tensorrt=True) + print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct)) + speedup = mx_duration / trt_duration + print("TensorRT speed-up (not counting compilation): %.2fx" % speedup) + + acc_diff = abs(mx_pct - trt_pct) + print("Absolute accuracy difference: %f" % acc_diff) + return speedup, acc_diff + + +def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1, test_fp16=False): + + models = [ + 'cifar_resnet20_v1', + 'cifar_resnet56_v1', + 'cifar_resnet110_v1', + 'cifar_resnet20_v2', + 'cifar_resnet56_v2', + 'cifar_resnet110_v2', + 'cifar_wideresnet16_10', + 'cifar_wideresnet28_10', + 'cifar_wideresnet40_8', + 'cifar_resnext29_16x64d' + ] + + num_models = len(models) + + speedups = np.zeros(num_models, dtype=np.float32) + acc_diffs = np.zeros(num_models, dtype=np.float32) + + precisions = ["fp32"] + if test_fp16: + precisions.append("fp16") + + for precision in precisions: + + test_start = time() + + print("\n\nRunning inference in %s\n\n" % precision) + use_fp16 = True if precision == "fp16" else False + for idx, model in enumerate(models): + speedup, acc_diff = run_experiment_for(model, batch_size, num_workers, fp16_for_fp32_graph=use_fp16) + speedups[idx] = speedup + acc_diffs[idx] = acc_diff + assert acc_diff < tolerance, "Accuracy difference between MxNet and TensorRT > %.2f%% for model %s" % (tolerance, model) + + print("Perf and correctness checks run on the following models:") + print(models) + mean_speedup = np.mean(speedups) + std_speedup = np.std(speedups) + print("\nSpeedups:") + print(speedups) + print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) + print("Mean speedup: %.2f" % mean_speedup) + print("St. dev. of speedups: %.2f" % std_speedup) + print("\nAcc. differences: %s" % str(acc_diffs)) + + test_duration = time() - test_start + + print("Test duration: %.2f seconds" % test_duration) + +if __name__ == '__main__': + import nose + nose.runmodule() From 4855d8aa7ea068f5d3c4c796a78c1c11c9a14b4c Mon Sep 17 00:00:00 2001 From: Clement Fuji Tsang Date: Thu, 26 Jul 2018 00:16:54 +0200 Subject: [PATCH 02/43] correctly assign self._optimized_symbol in executor --- python/mxnet/executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index adb532ddf4fa..b71127ef0cf8 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -337,7 +337,8 @@ def optimized_symbol(self): if self._optimized_symbol is None: handle = SymbolHandle() check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(handle))) - return mx.sym.Symbol(handle=handle) + self._optimized_symbol = mx.sym.Symbol(handle=handle) + return self._optimized_symbol def copy_params_from(self, arg_params, aux_params=None, allow_extra_params=False): """Copy parameters from arg_params, aux_params into executor's internal array. From 419b294bf14b2badea610d160ee2384625b19e62 Mon Sep 17 00:00:00 2001 From: Clement Fuji Tsang Date: Thu, 26 Jul 2018 00:24:15 +0200 Subject: [PATCH 03/43] declare GetTrtCompatibleSubsets and ReplaceSubgraph only if MXNET_USE_TENSORRT --- src/executor/exec_pass.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 8a00b2e8cd24..111ee214cdda 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -209,6 +209,7 @@ bool DefaultStorageType(const nnvm::NodeAttrs& attrs, std::vector *iattr, std::vector *oattr); +#if MXNET_USE_TENSORRT /*! * \brief Replace subgraphs by TRT (forward only) */ @@ -218,6 +219,7 @@ Graph ReplaceSubgraph(Graph&& g, std::vector> GetTrtCompatibleSubsets(const Graph& g, std::unordered_map* const params_map); +#endif } // namespace exec } // namespace mxnet From 8d723c9226d2d59729ed3b20ffadbae79fdd2900 Mon Sep 17 00:00:00 2001 From: Clement Fuji Tsang Date: Thu, 26 Jul 2018 01:48:19 +0200 Subject: [PATCH 04/43] add comments in ReplaceSubgraph --- src/executor/tensorrt_pass.cc | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc index b2eeec1bf82c..9c33bf09fc01 100644 --- a/src/executor/tensorrt_pass.cc +++ b/src/executor/tensorrt_pass.cc @@ -466,6 +466,9 @@ Graph ReplaceSubgraph(Graph&& g, const auto sub_outputs_in_main = GetSubgraphOutputs(g, set_subgraph); subgraph.outputs = sub_outputs_in_main; + + // old2new will link raw pointer of the nodes in the graph to + // the corresponding shared_ptr of the nodes in the generated subgraph std::unordered_map old2new; std::deque stack; std::unordered_set visited; @@ -473,10 +476,14 @@ Graph ReplaceSubgraph(Graph&& g, old2new.reserve(reservation); visited.reserve(reservation); + // Create the shared_ptr using the same raw pointer don't really matter for (auto& n : set_subgraph) { old2new[n] = std::make_shared(*n); } + // To generate a subgraph an input have to be replace by data node (no op) + // and it have to be agnostic to the node from which it's an output + // (For exemple even if two inputs are two different outputs from the same node) nnvm::NodeEntryMap main_input_entry_to_sub; for (auto& e : GetSubgraphInterfaceNodes(g, set_subgraph)) { auto node = nnvm::Node::Create(); @@ -489,7 +496,8 @@ Graph ReplaceSubgraph(Graph&& g, e.node = old2new[e.node.get()]; stack.emplace_back(e.node.get()); } - + + // link all nodes in the subgraph to nodes in the subgraph instead of main graph while (!stack.empty()) { auto vertex = stack.front(); stack.pop_front(); @@ -506,7 +514,8 @@ Graph ReplaceSubgraph(Graph&& g, } } } - + + // Remove the control dependencies of the subgraph to nodes that are not in the subgraph DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) { std::remove_if(node->control_deps.begin(), node->control_deps.end(), @@ -530,6 +539,7 @@ Graph ReplaceSubgraph(Graph&& g, sub_input_entryid_to_main[sub_idx.entry_id(p.second)] = p.first; } + // Plug the nodes from the main graph as inputs of the trt node trtnodeptr->inputs.resize(main_input_entry_to_sub.size()); { uint32_t counter = 0; @@ -540,11 +550,13 @@ Graph ReplaceSubgraph(Graph&& g, } } } - + + nnvm::NodeEntryMap sub_outputs_in_main_to_pos; for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) { sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i; } + // Plug the trt node as inputs to the main graph nodes DFSVisit(g.outputs, [&sub_outputs_in_main_to_pos, &trtnodeptr](const nnvm::NodePtr& n) { for (auto& e : n->inputs) { auto it = sub_outputs_in_main_to_pos.find(e); From ca9462451b4c0f3b274c81fe68d8ce4eac7bcb99 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Wed, 25 Jul 2018 20:11:13 -0700 Subject: [PATCH 05/43] Addressing Haibin's code review points --- src/executor/graph_executor.cc | 7 ++++--- src/executor/graph_executor.h | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index d06ffc3cdd71..2b6f7e7475d9 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -46,6 +46,7 @@ namespace exec { GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); + use_tensorrt_ = dmlc::GetEnv("MXNET_USE_TENSORRT", false); need_grad_ = false; } @@ -790,7 +791,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, aux_state_vec->emplace_back(aux_nd); } else { auto it = shared_buffer->find(arg_name); - if ( it != shared_buffer->end() ) { + if (it != shared_buffer->end()) { aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top]))); } else { aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, @@ -855,7 +856,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // !shared_arg_names.count(arg_name) // model parameter, row_sparse ndarray sharing enabled bool enable_row_sparse_sharing = true; - if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + if (use_tensorrt_) { #if MXNET_USE_TENSORRT auto it = shared_buffer->find(arg_name); if (it != shared_buffer->end()) { @@ -1155,7 +1156,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g.GetAttr("storage_type")); } - if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { + if (use_tensorrt_) { #if MXNET_USE_TENSORRT auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); for (auto trt_group : trt_groups) { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index e380344bf7ce..acf539a183dc 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -270,6 +270,8 @@ class GraphExecutor : public Executor { std::unordered_set cached_seg_opr_names_; // verbose logging bool log_verbose_ = false; + // use TensorRT optimization pass for inference + bool use_tensorrt_ = false; }; } // namespace exec From 75c864257c62138a27e4da01b6842739f1dde387 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Thu, 26 Jul 2018 08:48:09 -0700 Subject: [PATCH 06/43] Check that shared_buffer is not empty when USE_TENSORRT is set --- src/executor/graph_executor.cc | 7 ++++++- src/executor/tensorrt_pass.cc | 5 ----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 2b6f7e7475d9..718daad0f638 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -46,7 +46,7 @@ namespace exec { GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); - use_tensorrt_ = dmlc::GetEnv("MXNET_USE_TENSORRT", false); + use_tensorrt_ = dmlc::GetEnv("MXNET_USE_TENSORRT", false); need_grad_ = false; } @@ -1158,6 +1158,11 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (use_tensorrt_) { #if MXNET_USE_TENSORRT + if (shared_buffer->empty()) { + LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty." + << "Please provide weights and other parameters, such as " + << "BatchNorm moments, via the shared_buffer, during simple bind call."; + } auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); for (auto trt_group : trt_groups) { if (trt_group.size() > 1) { diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc index 9c33bf09fc01..b36cde088ef8 100644 --- a/src/executor/tensorrt_pass.cc +++ b/src/executor/tensorrt_pass.cc @@ -466,7 +466,6 @@ Graph ReplaceSubgraph(Graph&& g, const auto sub_outputs_in_main = GetSubgraphOutputs(g, set_subgraph); subgraph.outputs = sub_outputs_in_main; - // old2new will link raw pointer of the nodes in the graph to // the corresponding shared_ptr of the nodes in the generated subgraph std::unordered_map old2new; @@ -496,7 +495,6 @@ Graph ReplaceSubgraph(Graph&& g, e.node = old2new[e.node.get()]; stack.emplace_back(e.node.get()); } - // link all nodes in the subgraph to nodes in the subgraph instead of main graph while (!stack.empty()) { auto vertex = stack.front(); @@ -514,7 +512,6 @@ Graph ReplaceSubgraph(Graph&& g, } } } - // Remove the control dependencies of the subgraph to nodes that are not in the subgraph DFSVisit(subgraph.outputs, [&set_subgraph, &old2new](const nnvm::NodePtr& node) { std::remove_if(node->control_deps.begin(), @@ -550,8 +547,6 @@ Graph ReplaceSubgraph(Graph&& g, } } } - - nnvm::NodeEntryMap sub_outputs_in_main_to_pos; for (uint32_t i = 0; i < sub_outputs_in_main.size(); ++i) { sub_outputs_in_main_to_pos[sub_outputs_in_main[i]] = i; From 2a114665ce9342dbb808d9de63cda99fe209a415 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Thu, 26 Jul 2018 14:54:51 -0700 Subject: [PATCH 07/43] Added check that TensorRT binding is for inference only --- src/executor/graph_executor.cc | 9 +++++++-- src/executor/tensorrt_pass.cc | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 718daad0f638..6711d439d1ce 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -856,7 +856,8 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // !shared_arg_names.count(arg_name) // model parameter, row_sparse ndarray sharing enabled bool enable_row_sparse_sharing = true; - if (use_tensorrt_) { + + if (use_tensorrt_ && !need_grad_) { #if MXNET_USE_TENSORRT auto it = shared_buffer->find(arg_name); if (it != shared_buffer->end()) { @@ -870,6 +871,9 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << "built with TensorRT. Add USE_TENSORRT = 1 in config.mk"; #endif } else { + if (use_tensorrt_) { + LOG(WARNING) << "USE_TENSORRT=1 but grads required. Running without TensorRT"; + } in_arg_vec->emplace_back(ReshapeOrCreate( arg_name, inferred_shape, inferred_dtype, inferred_stype, in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing)); @@ -1156,8 +1160,9 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g.GetAttr("storage_type")); } - if (use_tensorrt_) { + if (use_tensorrt_ && !need_grad_) { #if MXNET_USE_TENSORRT + // check that this graph is inference-only if (shared_buffer->empty()) { LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty." << "Please provide weights and other parameters, such as " diff --git a/src/executor/tensorrt_pass.cc b/src/executor/tensorrt_pass.cc index b36cde088ef8..b5fc8d15f7ac 100644 --- a/src/executor/tensorrt_pass.cc +++ b/src/executor/tensorrt_pass.cc @@ -262,7 +262,7 @@ using NodeEntrySet = std::unordered_set GetSubgraphOutputs(Graph g, +std::vector GetSubgraphNodeEntries(Graph g, std::unordered_set set_subgraph) { std::vector outputs; NodeEntrySet _outputs; @@ -464,7 +464,7 @@ Graph ReplaceSubgraph(Graph&& g, // Create MXNet subgraph Graph subgraph; - const auto sub_outputs_in_main = GetSubgraphOutputs(g, set_subgraph); + const auto sub_outputs_in_main = GetSubgraphNodeEntries(g, set_subgraph); subgraph.outputs = sub_outputs_in_main; // old2new will link raw pointer of the nodes in the graph to // the corresponding shared_ptr of the nodes in the generated subgraph From 190c9bfa551e0fb2499f41783fc21f0ae62dab4a Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Wed, 1 Aug 2018 19:58:21 -0700 Subject: [PATCH 08/43] Removed redundant decl. --- src/executor/exec_pass.h | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 111ee214cdda..8c483e9b2b8e 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -198,17 +198,6 @@ Graph InferStorageType(Graph&& graph, StorageTypeVector&& storage_type_inputs = StorageTypeVector(), const std::string& storage_type_attr_key = ""); -/*! \brief The default storage type inference function, which assigns all undefined - * storage types to kDefaultStorage. If all of input and output storage types - * are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise, - * DispatchMode::kFComputeFallback is assigned to dispatch_mode. - */ -bool DefaultStorageType(const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *iattr, - std::vector *oattr); - #if MXNET_USE_TENSORRT /*! * \brief Replace subgraphs by TRT (forward only) From d88ad8bd110eb3a2389b94723d1e6d2c929b4be0 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 15:33:34 +0200 Subject: [PATCH 09/43] WIP Refactored TRT integration and tests --- CMakeLists.txt | 38 +- include/mxnet/c_api.h | 38 ++ include/mxnet/executor.h | 37 +- python/mxnet/contrib/__init__.py | 1 + python/mxnet/contrib/tensorrt.py | 226 +++++++++ src/c_api/c_api_executor.cc | 354 +++++++++++++- src/executor/graph_executor.cc | 277 +++-------- src/executor/graph_executor.h | 96 ++-- src/executor/trt_graph_executor.cc | 448 ++++++++++++++++++ src/executor/trt_graph_executor.h | 91 ++++ tests/python/tensorrt/test_cvnets.py | 185 ++++++++ .../test_tensorrt_resnet_resnext_ssd.py | 260 ---------- 12 files changed, 1516 insertions(+), 535 deletions(-) create mode 100644 python/mxnet/contrib/tensorrt.py create mode 100644 src/executor/trt_graph_executor.cc create mode 100644 src/executor/trt_graph_executor.h create mode 100644 tests/python/tensorrt/test_cvnets.py delete mode 100644 tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e5744249dc0d..95e9f52890ad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,11 +188,47 @@ endif() if(USE_TENSORRT) message(STATUS "Using TensorRT") - include_directories(3rdparty/onnx-tensorrt/third_party/onnx/build/) + set(ONNX_PATH 3rdparty/onnx-tensorrt/third_party/onnx/build/) + set(ONNX_TRT_PATH 3rdparty/onnx-tensorrt/build/) + + include_directories(${ONNX_PATH}) include_directories(3rdparty/onnx-tensorrt/) include_directories(3rdparty/) add_definitions(-DMXNET_USE_TENSORRT=1) add_definitions(-DONNX_NAMESPACE=onnx) + list(APPEND mxnet_LINKER_LIBS libnvinfer.so) + + find_package(Protobuf REQUIRED) + list(APPEND mxnet_LINKER_LIBS ${PROTOBUF_LIBRARY}) + + message(STATUS "target path ${ONNX_PATH}") + find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED + PATHS ${ONNX_PATH} + DOC "Path to onnx library.") + message(STATUS "linking onnx ${ONNX_LIBRARY}") + list(APPEND mxnet_LINKER_LIBS ${ONNX_LIBRARY}) + + message(STATUS "target path ${ONNX_PATH}") + find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED + PATHS ${ONNX_PATH} + DOC "Path to onnx_proto library.") + message(STATUS "linking proto onnx ${ONNX_PROTO_LIBRARY}") + list(APPEND mxnet_LINKER_LIBS ${ONNX_PROTO_LIBRARY}) + + message(STATUS "target path ${ONNX_TRT_PATH}") + find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED + PATHS ${ONNX_TRT_PATH} + DOC "Path to onnx_proto library.") + message(STATUS "linking proto onnx ${ONNX_TRT_RUNTIME_LIBRARY}") + list(APPEND mxnet_LINKER_LIBS ${ONNX_TRT_RUNTIME_LIBRARY}) + + message(STATUS "target path ${ONNX_TRT_PATH}") + find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED + PATHS ${ONNX_TRT_PATH} + DOC "Path to onnx_proto library.") + message(STATUS "linking proto onnx ${ONNX_TRT_PARSER_LIBRARY}") + list(APPEND mxnet_LINKER_LIBS ${ONNX_TRT_PARSER_LIBRARY}) + endif() if(USE_MKLDNN) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 58b1b1b4dafe..319e6d8f255c 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1445,6 +1445,9 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, const int **aux_type_data, int *complete); +MXNET_DLL int MXTRTSymbolOptimize(SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle); + /*! * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 * \param sym_handle symbol to be converted @@ -1674,6 +1677,41 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, ExecutorHandle shared_exec_handle, ExecutorHandle* out); +MXNET_DLL int MXExecutorTensorRTBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); + /*! * \brief Return a new executor with the same symbol and shared memory, * but different input/output shapes. diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 11d3c5b0127b..5b074cedd67b 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -152,20 +152,39 @@ class Executor { static Executor* SimpleBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, - std::vector* in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - std::unordered_set* param_names, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& param_names, std::vector* in_args, std::vector* arg_grads, std::vector* aux_states, std::unordered_map* - shared_data_arrays = nullptr, + shared_data_arrays = nullptr, Executor* shared_exec = nullptr); + + static Executor* TensorRTBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + std::vector *in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* + shared_data_arrays = nullptr, + Executor* shared_exec = nullptr); + /*! * \brief the prototype of user-defined monitor callback */ diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py index fbfd3469678b..606bb0ada54f 100644 --- a/python/mxnet/contrib/__init__.py +++ b/python/mxnet/contrib/__init__.py @@ -32,3 +32,4 @@ from . import io from . import quantization from . import quantization as quant +from . import tensorrt diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py new file mode 100644 index 000000000000..1cfec707d183 --- /dev/null +++ b/python/mxnet/contrib/tensorrt.py @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Module to enable the use of TensorRT optimized graphs.""" +# pylint: skip-file + +import ctypes + +from ..base import _LIB, c_array, c_array_buf, c_str_array, c_handle_array +from ..base import mx_uint, py_str, string_types +from ..base import NDArrayHandle, ExecutorHandle +from ..base import check_call, MXNetError +from array import array +from ..ndarray import _ndarray_cls +from ..executor import Executor + +import numpy as _numpy + + +def optimize_graph(sym, ctx, grad_req='write', type_dict=None, stype_dict=None, + group2ctx=None, shared_arg_names=None, shared_exec=None, + shared_buffer=None, **kwargs): + num_provided_arg_types = 0 + provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names + provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types + if type_dict is not None: + provided_arg_type_names = [] + provided_arg_type_data = [] + for k, v in type_dict.items(): + v = _numpy.dtype(v).type + if v in _DTYPE_NP_TO_MX: + provided_arg_type_names.append(k) + provided_arg_type_data.append(_DTYPE_NP_TO_MX[v]) + num_provided_arg_types = mx_uint(len(provided_arg_type_names)) + provided_arg_type_names = c_str_array(provided_arg_type_names) + provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data)) + + # storage types + num_provided_arg_stypes = 0 + # provided storage type argument names + provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() + provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if stype_dict is not None: + provided_arg_stype_names = [] + provided_arg_stype_data = [] + for k, v in stype_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + provided_arg_stype_names.append(k) + provided_arg_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v]) + num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) + provided_arg_stype_names = c_str_array(provided_arg_stype_names) + provided_arg_stype_data = c_array_buf(ctypes.c_int, array('i', provided_arg_stype_data)) + + provided_arg_shape_data = [] # shape data + # argument shape index in sdata, + # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg + provided_arg_shape_idx = [0] + provided_arg_shape_names = [] # provided argument names + for k, v in kwargs.items(): + # if k not in listed_arguments and k not in listed_aux_states: + # raise ValueError('arg name %s is not valid', k) + if isinstance(v, tuple): + provided_arg_shape_names.append(k) + provided_arg_shape_data.extend(v) + provided_arg_shape_idx.append(len(provided_arg_shape_data)) + + provided_req_type_list_len = 0 + provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)() + provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)() + if grad_req is not None: + if isinstance(grad_req, string_types): + # use provided_req_type_list_len = 0 to indicate this situation + provided_req_type_list_len = 0 + provided_grad_req_types = [grad_req] + elif isinstance(grad_req, list): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty list') + provided_grad_req_types = grad_req + provided_req_type_list_len = len(provided_grad_req_types) + elif isinstance(grad_req, dict): + if len(grad_req) == 0: + raise RuntimeError('grad_req in simple_bind cannot be an empty dict') + provided_grad_req_names = [] + provided_grad_req_types = [] + for k, v in grad_req.items(): + provided_grad_req_names.append(k) + provided_grad_req_types.append(v) + provided_grad_req_names = c_str_array(provided_grad_req_names) + provided_req_type_list_len = len(provided_grad_req_types) + provided_grad_req_types = c_str_array(provided_grad_req_types) + + num_ctx_map_keys = mx_uint(0) + ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)() + ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)() + ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)() + if group2ctx is not None: + ctx_map_keys = [] + ctx_map_dev_types = [] + ctx_map_dev_ids = [] + for key, val in group2ctx.items(): + ctx_map_keys.append(key) + ctx_map_dev_types.append(val.device_typeid) + ctx_map_dev_ids.append(val.device_id) + num_ctx_map_keys = mx_uint(len(ctx_map_keys)) + ctx_map_keys = c_str_array(ctx_map_keys) + ctx_map_dev_types = c_array(ctypes.c_int, array('i', ctx_map_dev_types)) + ctx_map_dev_ids = c_array(ctypes.c_int, array('i', ctx_map_dev_ids)) + + # prepare param names + shared_arg_name_list = [] + if shared_arg_names is not None: + if not isinstance(shared_arg_names, list): + raise ValueError('shared_arg_names in simple_bind must be a list or None') + shared_arg_name_list = shared_arg_names + + # prepare shared_buffer + if shared_buffer is None: + shared_buffer_len = ctypes.c_int(-1) + shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() + else: + if not isinstance(shared_buffer, dict): + raise ValueError('shared_buffer in simple_bind must be dict or None') + buffer_names = shared_buffer.keys() + buffer_arrays = shared_buffer.values() + for v in buffer_arrays: + assert(v.stype == 'default'), \ + "shared_buffer is expected to only contain NDArrays with default storage" + shared_buffer_names = c_str_array(buffer_names) + shared_buffer_len = ctypes.c_int(len(buffer_arrays)) + shared_buffer_handles = c_handle_array(buffer_arrays) + updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() + updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() + + # prepare shared_exec_handle + shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle() + + # prepare current executor handle + exe_handle = ExecutorHandle() + + # prepare current executor's in_args, arg_grads, and aux_states + num_in_args = ctypes.c_uint() + in_arg_handles = ctypes.POINTER(NDArrayHandle)() + arg_grad_handles = ctypes.POINTER(NDArrayHandle)() + num_aux_states = ctypes.c_uint() + aux_state_handles = ctypes.POINTER(NDArrayHandle)() + + try: + check_call(_LIB.MXExecutorTensorRTBind(sym.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_uint, + array('I', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('I', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + except MXNetError as e: + error_msg = "simple_bind error. Arguments:\n" + for k, v in kwargs.items(): + error_msg += "%s: %s\n" % (k, v) + error_msg += "%s" % e + raise RuntimeError(error_msg) + + # update shared_buffer + if shared_buffer is not None: + for i in range(shared_buffer_len.value): + k = py_str(updated_shared_buffer_names[i]) + v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i])) + shared_buffer[k] = v + + # create in_args, arg_grads, and aux_states for the current executor + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) + for i in range(num_in_args.value)] + grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) + if arg_grad_handles[i] is not None + else None for i in range(num_in_args.value)] + aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) + for i in range(num_aux_states.value)] + + executor = Executor(exe_handle, sym, ctx, grad_req, group2ctx) + executor.arg_arrays = arg_arrays + executor.grad_arrays = grad_arrays + executor.aux_arrays = aux_arrays + return executor \ No newline at end of file diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 723677ccc38f..30f692e63d8a 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -27,6 +27,7 @@ #include #include "./c_api_common.h" #include "../executor/graph_executor.h" +#include "../executor/trt_graph_executor.h" int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); @@ -441,9 +442,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector arg_grad_vec; std::vector aux_state_vec; - *out = Executor::SimpleBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, - &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map, - &grad_req_type_vec, &shared_arg_name_set, &in_arg_vec, + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); @@ -511,6 +512,345 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, API_END(); } +#if MXNET_USE_TENSORRT + +/*! + * \brief + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorTensorRTBind(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(symbol_handle); + + // get in_arg names + std::vector in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); + std::vector aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); + + // attr_dict for setting up type_dict and arg/aux ctx + std::unordered_map> attr_dict; + if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) { + std::vector> attrs = + sym->ListAttrsRecursive(); + attr_dict.reserve(attrs.size()); + for (const auto& tp : attrs) { + attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp); + } + } + + // setup arg_dtype_map + std::unordered_map arg_dtype_map; + if (nullptr == provided_arg_dtypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__dtype__")) { + arg_dtype_map[arg_name] = mshadow::kFloat32; + } + } + } else { // use user input type_dict + // create dtype map for in_args and aux_states + arg_dtype_map.reserve(num_provided_arg_dtypes); + for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { + arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; + } + } + + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + + // create default ctx + Context ctx = Context::Create(static_cast(dev_type), dev_id); + // create ctx map + std::map ctx_map; + std::vector in_arg_ctx_vec(in_arg_names.size(), ctx); + std::vector aux_state_ctx_vec(aux_state_names.size(), ctx); + if (nullptr != g2c_keys) { // use user input group2ctx dict + for (mx_uint i = 0; i < num_g2c_keys; ++i) { + ctx_map[g2c_keys[i]] = Context::Create( + static_cast(g2c_dev_types[i]), g2c_dev_ids[i]); + } + + // initialize in_arg_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(in_arg_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + in_arg_ctx_vec[i] = it3->second; + } + } + } + } + + // initialize aux_state_ctx_vec using group2ctx if there are any + for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) { + const auto it1 = attr_dict.find(aux_state_names[i]); + if (it1 != attr_dict.end()) { + const auto it2 = it1->second.find("__ctx_group__"); + if (it2 != it1->second.end()) { + const auto it3 = ctx_map.find(it2->second); + if (it3 != ctx_map.end()) { + aux_state_ctx_vec[i] = it3->second; + } + } + } + } + } + + // create provided_grad_req_map + const std::map req_map = + {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}}; + std::unordered_map provided_grad_req_map; + std::string grad_req_type; + if (0 == provided_grad_req_list_len + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // string, grad_req='write' + CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U) + << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' " + "are supported"; + grad_req_type = "string"; + } else if (provided_grad_req_list_len > 0 + && nullptr == provided_grad_req_names + && nullptr != provided_grad_req_types) { // list, grad_req=['null', 'write'] + grad_req_type = "list"; + CHECK_EQ(provided_grad_req_list_len, in_arg_names.size()) + << "The length of grad_req list does not match the number of input arguments in " + "simple_bind, expected " << in_arg_names.size() << ", provided " << + provided_grad_req_list_len; + } else if (provided_grad_req_list_len > 0 + && nullptr != provided_grad_req_names + && nullptr != provided_grad_req_types) { // dict, grad_req=['lhs': 'null', 'rhs': + // 'write'] + grad_req_type = "dict"; + provided_grad_req_map.reserve(provided_grad_req_list_len); + for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and \'add\' " + "are supported"; + provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i]; + } + } else { // grad_req is None + grad_req_type = "none"; + } + + // initialize arg_grad_ctx_vec and grad_req_type_vec + std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); + std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); + if ("none" != grad_req_type) { + for (size_t i = 0; i < in_arg_names.size(); ++i) { + OpReqType cur_req = kNullOp; + if ("string" == grad_req_type) { + cur_req = req_map.at(provided_grad_req_types[0]); + } else if ("list" == grad_req_type) { + CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) + << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " + "only \'null\', \'write\', and " + "\'add\' are supported"; + cur_req = req_map.at(provided_grad_req_types[i]); + } else if ("dict" == grad_req_type) { + const auto it = provided_grad_req_map.find(in_arg_names[i]); + if (it != provided_grad_req_map.end()) { + cur_req = req_map.at(it->second); + } + } + if (kNullOp != cur_req) { + arg_grad_ctx_vec[i] = in_arg_ctx_vec[i]; + grad_req_type_vec[i] = static_cast(cur_req); + } + } + } + + // create shape map for in_args and aux_states + std::unordered_map arg_shape_map(num_provided_arg_shapes); + for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { + auto p = arg_shape_map.emplace(provided_arg_shape_names[i], + TShape(provided_arg_shape_data+provided_arg_shape_idx[i], + provided_arg_shape_data+provided_arg_shape_idx[i+1])); + CHECK(p.second) << "Duplicate shapes are provided for argument " + << provided_arg_shape_names[i] << " in simple_bind"; + } + + // create para name set for sharing data array memory + std::unordered_set shared_arg_name_set(num_shared_arg_names); + for (mx_uint i = 0; i < num_shared_arg_names; ++i) { + shared_arg_name_set.insert(shared_arg_name_list[i]); + } + + // create shared_buffer_map + std::unordered_map shared_buffer_map; + bool use_shared_buffer = (*shared_buffer_len >= 0); + if (*shared_buffer_len > 0) { + // create shared_buffer_map + shared_buffer_map.reserve(*shared_buffer_len); + NDArray** shared_buffer_ptrs = + reinterpret_cast(shared_buffer_handle_list); + for (int i = 0; i < *shared_buffer_len; ++i) { + shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]); + } + } + + // create temporary place holders for the initialized NDArrays + // to be passed back to front end + std::vector in_arg_vec; + std::vector arg_grad_vec; + std::vector aux_state_vec; + + *out = Executor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, + &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map, + &grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + + // copy ndarray ptrs to ret->handles so that front end + // can access them + ret->ret_handles.clear(); + ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size() + +shared_buffer_map.size()); + size_t nd_idx = 0; + for (const auto& nd : in_arg_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (in_arg_vec.size() > 0) { + *num_in_args = in_arg_vec.size(); + *in_args = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : arg_grad_vec) { + if (nd.is_none()) { + ret->ret_handles.push_back(nullptr); + } else { + ret->ret_handles.push_back(new NDArray(nd)); + } + } + if (arg_grad_vec.size() > 0) { + *arg_grads = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + for (const auto& nd : aux_state_vec) { + if (nd.is_none()) { + LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(nd)); + } + if (aux_state_vec.size() > 0) { + *num_aux_states = aux_state_vec.size(); + *aux_states = &(ret->ret_handles[nd_idx]); + nd_idx = ret->ret_handles.size(); + } + + if (use_shared_buffer) { + ret->ret_vec_str.clear(); + ret->ret_vec_str.reserve(shared_buffer_map.size()); + ret->ret_vec_charp.clear(); + ret->ret_vec_charp.reserve(shared_buffer_map.size()); + for (const auto& kv : shared_buffer_map) { + if (kv.second.is_none()) { + LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; + } + ret->ret_handles.push_back(new NDArray(kv.second)); + ret->ret_vec_str.emplace_back(kv.first); + ret->ret_vec_charp.push_back(ret->ret_vec_str.back().c_str()); + } + *shared_buffer_len = shared_buffer_map.size(); + *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); + *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]); + } + + API_END(); +} + +#endif // MXNET_USE_TENSORRT + int MXExecutorReshape(int partial_shaping, int allow_up_sizing, int dev_type, @@ -598,16 +938,20 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete out); } +#if MXNET_USE_TENSORRT + int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); + auto s = new nnvm::Symbol(); API_BEGIN(); - exec::GraphExecutor *exec = static_cast(handle); + auto exec = static_cast(handle); *s = exec->GetOptimizedSymbol(); *out = s; API_END_HANDLE_ERROR(delete s); } +#endif // MXNET_USE_TENSORRT + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, void* callback_handle) { diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 6711d439d1ce..433e7aadce00 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -34,19 +34,11 @@ #include "../common/utils.h" #include "../common/exec_utils.h" -#if MXNET_USE_TENSORRT -#include -#include -#include "./onnx_to_tensorrt.h" -#include "../operator/contrib/tensorrt-inl.h" -#endif // MNET_USE_TENSORRT - namespace mxnet { namespace exec { GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); - use_tensorrt_ = dmlc::GetEnv("MXNET_USE_TENSORRT", false); need_grad_ = false; } @@ -64,7 +56,7 @@ GraphExecutor::~GraphExecutor() { } } -inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, +inline NDArray GraphExecutor::InitZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype) { // NDArray with default storage if (stype == kDefaultStorage) { @@ -76,7 +68,7 @@ inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, return NDArray(stype, shape, ctx, true, dtype); } -inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, +inline void GraphExecutor::EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype, std::vector *vec) { // NDArray with default storage @@ -320,15 +312,15 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, * \brief Assign context to the graph. * This is triggered by both simple_bind and bind flows. */ -static Graph AssignContext(Graph g, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - size_t num_forward_inputs, - size_t num_forward_outputs) { +Graph GraphExecutor::AssignContext(Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + size_t num_forward_inputs, + size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); const auto& mutable_nodes = idx.mutable_input_nodes(); // default use default context. @@ -445,7 +437,7 @@ static Graph AssignContext(Graph g, return g; } -static void HandleInferShapeError(const size_t num_forward_inputs, +void GraphExecutor::HandleInferShapeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes) { int cnt = 10; @@ -468,7 +460,7 @@ static void HandleInferShapeError(const size_t num_forward_inputs, << oss.str(); } -static void HandleInferTypeError(const size_t num_forward_inputs, +void GraphExecutor::HandleInferTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::DTypeVector& inferred_dtypes) { int cnt = 10; @@ -491,7 +483,7 @@ static void HandleInferTypeError(const size_t num_forward_inputs, << oss.str(); } -static void HandleInferStorageTypeError(const size_t num_forward_inputs, +void GraphExecutor::HandleInferStorageTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const StorageTypeVector& inferred_stypes) { int cnt = 10; @@ -696,13 +688,13 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, * Shareable storages include both default storage and row_sparse storage * if enable_row_sparse_sharing is `True`, otherwise default storage only. */ -static NDArray ReshapeOrCreate(const std::string& name, - const TShape& dest_arg_shape, - const int dest_arg_dtype, - const NDArrayStorageType dest_arg_stype, - const Context& ctx, - std::unordered_map* shared_buffer, - bool enable_row_sparse_sharing) { +NDArray GraphExecutor::ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, + const Context& ctx, + std::unordered_map* shared_buffer, + bool enable_row_sparse_sharing) { bool stype_shareable = dest_arg_stype == kDefaultStorage; if (enable_row_sparse_sharing) { stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; @@ -790,13 +782,8 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { - auto it = shared_buffer->find(arg_name); - if (it != shared_buffer->end()) { - aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top]))); - } else { - aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, - aux_state_ctxes[aux_top], inferred_dtype))); - } + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); } // if (has_shared_exec) data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); @@ -856,29 +843,9 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } else { // !shared_arg_names.count(arg_name) // model parameter, row_sparse ndarray sharing enabled bool enable_row_sparse_sharing = true; - - if (use_tensorrt_ && !need_grad_) { - #if MXNET_USE_TENSORRT - auto it = shared_buffer->find(arg_name); - if (it != shared_buffer->end()) { - in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); - } else { - in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, - in_arg_ctxes[arg_top], inferred_dtype))); - } - #else - LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't " - << "built with TensorRT. Add USE_TENSORRT = 1 in config.mk"; - #endif - } else { - if (use_tensorrt_) { - LOG(WARNING) << "USE_TENSORRT=1 but grads required. Running without TensorRT"; - } - in_arg_vec->emplace_back(ReshapeOrCreate( - arg_name, inferred_shape, inferred_dtype, inferred_stype, - in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing)); - } - + in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, + inferred_stype, in_arg_ctxes[arg_top], + shared_buffer, enable_row_sparse_sharing)); // gradient for model parameter, row_sparse ndarray sharing disabled if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); @@ -974,114 +941,6 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, this->InitOpSegs(); } -/*! - * \brief This function is triggered after each tensorrt subgraph replacement pass. - * Reset arguments of GraphExecutor::Init(...) as some variables (weights and biases) - * are absorbed into the TRT engine it also it rerun attributes inferences accordingly - * to the new topology. - */ -Graph GraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx, - const std::map &ctx_map, - std::vector *in_arg_ctxes, - std::vector *arg_grad_ctxes, - std::vector *aux_state_ctxes, - std::vector *grad_req_types, - std::unordered_map *arg_shape_map, - std::unordered_map *arg_dtype_map, - std::unordered_map *arg_stype_map, - std::unordered_map *params_map) { - std::unordered_set to_remove_params; - for (auto& el : *params_map) { - to_remove_params.insert(el.first); - } - - DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) { - to_remove_params.erase(n->attrs.name); - }); - - for (auto& el : to_remove_params) { - params_map->erase(el); - arg_shape_map->erase(el); - arg_dtype_map->erase(el); - arg_stype_map->erase(el); - } - const auto &idx = g.indexed_graph(); - num_forward_inputs_ = idx.input_nodes().size(); - in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); - arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); - grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); - aux_state_ctxes->resize(idx.mutable_input_nodes().size()); - - // create "device" and "context" attrs for the graph - g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, - *aux_state_ctxes, *grad_req_types, num_forward_inputs_, - num_forward_outputs_); - - // get number of nodes used in forward pass - num_forward_nodes_ = 0; - for (size_t i = 0; i < num_forward_outputs_; ++i) { - num_forward_nodes_ = std::max( - num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); - } - nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); - nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); - StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); - for (size_t i = 0; i < num_forward_inputs_; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const std::string &name = idx[nid].source->attrs.name; - auto it1 = arg_shape_map->find(name); - if (arg_shape_map->end() != it1) { - arg_shapes[i] = it1->second; - } - auto it2 = arg_dtype_map->find(name); - if (arg_dtype_map->end() != it2) { - arg_dtypes[i] = it2->second; - } - auto it3 = arg_stype_map->find(name); - if (arg_stype_map->end() != it3) { - arg_stypes[i] = it3->second; - } - } - g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - if (g.GetAttr("shape_num_unknown_nodes") != 0U) { - HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("shape")); - } - - g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); - if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { - HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("dtype")); - } - - g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); - - if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { - HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("storage_type")); - } - - return g; -} - -/*! - * \brief Return the "optimized" symbol contained in _graph. - * For optimization pass such as TensorRT pass - */ -nnvm::Symbol GraphExecutor::GetOptimizedSymbol() { - Symbol ret; - ret.outputs = std::vector(graph_.outputs.begin(), - graph_.outputs.begin() + num_forward_outputs_); - ret = ret.Copy(); - static const Op* trt_op = Op::Get("_trt_op"); - DFSVisit(ret.outputs, [](const nnvm::NodePtr n) { - if (n->op() == trt_op) { - n->attrs.dict.clear(); - } - }); - return ret; -} - /*! * \brief GraphExecutor initializer for simple bind flow in * which only certain input shapes and dtypes are provided by users. @@ -1099,23 +958,22 @@ nnvm::Symbol GraphExecutor::GetOptimizedSymbol() { void GraphExecutor::Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - std::vector* in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - std::unordered_set* shared_arg_names, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec, std::unordered_map* shared_buffer, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { - symbol = symbol.Copy(); - nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, - *aux_state_ctxes, *grad_req_types); + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types); // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation // should do this differently. @@ -1129,16 +987,16 @@ void GraphExecutor::Init(nnvm::Symbol symbol, for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; - auto it1 = arg_shape_map->find(name); - if (arg_shape_map->end() != it1) { + auto it1 = arg_shape_map.find(name); + if (arg_shape_map.end() != it1) { arg_shapes[i] = it1->second; } - auto it2 = arg_dtype_map->find(name); - if (arg_dtype_map->end() != it2) { + auto it2 = arg_dtype_map.find(name); + if (arg_dtype_map.end() != it2) { arg_dtypes[i] = it2->second; } - auto it3 = arg_stype_map->find(name); - if (arg_stype_map->end() != it3) { + auto it3 = arg_stype_map.find(name); + if (arg_stype_map.end() != it3) { arg_stypes[i] = it3->second; } } @@ -1160,43 +1018,20 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g.GetAttr("storage_type")); } - if (use_tensorrt_ && !need_grad_) { - #if MXNET_USE_TENSORRT - // check that this graph is inference-only - if (shared_buffer->empty()) { - LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty." - << "Please provide weights and other parameters, such as " - << "BatchNorm moments, via the shared_buffer, during simple bind call."; - } - auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); - for (auto trt_group : trt_groups) { - if (trt_group.size() > 1) { - g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer); - g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, - aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map, - arg_stype_map, shared_buffer); - } - } - #else - LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set but MXNet wasn't " - << "built with TensorRT. Add USE_TENSORRT = 1 to config.mk"; - #endif - } - // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. if (nullptr == shared_buffer) { // regular simple bind - InitArguments(g.indexed_graph(), g.GetAttr("shape"), + InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), g.GetAttr("storage_type"), - *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, - *grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec - InitArguments(g.indexed_graph(), g.GetAttr("shape"), + InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), g.GetAttr("storage_type"), - *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, - *grad_req_types, *shared_arg_names, shared_exec, + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); } // The above code of shape and dtype inferences and argument @@ -1869,14 +1704,14 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, Executor *Executor::SimpleBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, - std::vector* in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - std::unordered_set* shared_arg_names, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, std::vector* in_args, std::vector* arg_grads, std::vector* aux_states, diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index acf539a183dc..070ab609d7b4 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -87,25 +87,26 @@ class GraphExecutor : public Executor { Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict = nnvm::NodeEntryMap()); + // initialize executor for simple bind void Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, - std::vector* in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - std::unordered_set* shared_arg_names, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::unordered_map& arg_shape_map, + const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec, std::unordered_map* shared_buffer = nullptr, Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict - = nnvm::NodeEntryMap()); + = nnvm::NodeEntryMap()); Executor* Reshape(const bool partial_shaping, const bool allow_up_sizing, @@ -117,8 +118,6 @@ class GraphExecutor : public Executor { std::vector* arg_grads, std::vector* aux_states) override; - nnvm::Symbol GetOptimizedSymbol(); - protected: friend class mxnet::Imperative; // Information about operational node @@ -163,22 +162,24 @@ class GraphExecutor : public Executor { std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec); + // Initialize in_args, arg_grads and aux_states with // shared_buffer and shared_exec - void InitArguments(const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes, - const nnvm::DTypeVector& inferred_dtypes, - const StorageTypeVector& inferred_stypes, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, - const Executor* shared_exec, - std::unordered_map* shared_buffer, - std::vector* in_arg_vec, - std::vector* arg_grad_vec, - std::vector* aux_state_vec); + virtual void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); + // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, const Context& default_ctx, @@ -214,21 +215,40 @@ class GraphExecutor : public Executor { void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); - + static void HandleInferShapeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes); + static void HandleInferTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes); + static void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const StorageTypeVector& inferred_stypes); + static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype); + static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec); + static NDArray ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, + const Context& ctx, + std::unordered_map* shared_buffer, + bool enable_row_sparse_sharing); + // Assign context to the graph. + static Graph AssignContext(Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + size_t num_forward_inputs, + size_t num_forward_outputs); // indicate whether there is a backward graph for gradients. bool need_grad_; - Graph ReinitGraph(Graph&& g, const Context &default_ctx, - const std::map &ctx_map, - std::vector* in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::vector* grad_req_types, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::unordered_map *params_map); - // internal graph nnvm::Graph graph_; // operator node @@ -270,8 +290,6 @@ class GraphExecutor : public Executor { std::unordered_set cached_seg_opr_names_; // verbose logging bool log_verbose_ = false; - // use TensorRT optimization pass for inference - bool use_tensorrt_ = false; }; } // namespace exec diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc new file mode 100644 index 000000000000..be9e3d7d765a --- /dev/null +++ b/src/executor/trt_graph_executor.cc @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_TENSORRT + +#include "trt_graph_executor.h" + +#include +#include +#include "./onnx_to_tensorrt.h" +#include "../operator/contrib/tensorrt-inl.h" + +namespace mxnet { +namespace exec { + + /*! + * \brief GraphExecutor initializer for simple bind flow in + * which only certain input shapes and dtypes are provided by users. + * The initializer uses these shapes and dtypes to perform + * shape and dtype inferences, and then create NDArrays + * to populate data entries of the graph. The created NDArrays + * for in_args, arg_grads and aux_states are passed to the + * front end to attach the created executor. + * In front end, if the simple_bind flow is trigger by + * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup + * and shared executor will be taken into account in creating + * NDArrays for in_args, arg_grads, and aux_states for resuing + * already allocated memory. + */ +void TrtGraphExecutor::Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer, + Executor* shared_exec, + const nnvm::NodeEntryMap& feed_dict) { + symbol = symbol.Copy(); + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types); + // The following code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize arg_shapes and arg_dtypes for shape and type inferences. + // It contains all in_args and aux_states' shapes and types in a certain order. + const nnvm::IndexedGraph& idx = g.indexed_graph(); + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string& name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { + arg_stypes[i] = it3->second; + } + } + g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } + + g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + + if (shared_buffer->empty()) { + LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty." + << "Please provide weights and other parameters, such as " + << "BatchNorm moments, via the shared_buffer, during simple bind call."; + } + auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); + for (auto trt_group : trt_groups) { + if (trt_group.size() > 1) { + g = ReplaceSubgraph(std::move(g), trt_group, shared_buffer); + g = ReinitGraph(std::move(g), default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, + aux_state_ctxes, grad_req_types, arg_shape_map, arg_dtype_map, + arg_stype_map, shared_buffer); + } + } + + + InitArguments(g.indexed_graph(), g.GetAttr("shape"), + g.GetAttr("dtype"), + g.GetAttr("storage_type"), + *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, + *grad_req_types, shared_arg_names, shared_exec, + shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); + + // The above code of shape and dtype inferences and argument + // initialization is for simple_bind only. Regular bind operation + // should do this differently. + + // Initialize the rest attributes of the graph. + // This function can be called by regular bind + // operation flow as well. + FinishInitGraph(symbol, g, shared_exec, feed_dict); +} +/*! + * \brief Initialize in_args, arg_grads, and aux_states + * and their data_entry_ of the executor using + * shared_buffer from DataParallelExecutorGroup + * and shared_exec if available. + */ +void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) { + bool use_tensorrt_ = true; + // initialize in_args, arg_grads, and aux_states and populate grad_store_ + data_entry_.resize(idx.num_node_entries()); + size_t arg_top = 0, aux_top = 0; + const auto& mutable_nodes = idx.mutable_input_nodes(); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; + const std::string& arg_name = idx[nid].source->attrs.name; + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec) { + const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); + CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage) + << "Non-default storage type detected when creating auxilliary NDArray. The allocated " + << "memory of shared_exec.aux_array cannot be resued for argument: " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_shape, aux_nd.shape()) + << "Inferred shape does not match shared_exec.aux_array's shape." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument: " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, aux_nd.dtype()) + << "Inferred dtype does not match shared_exec.aux_array's dtype." + " Therefore, the allocated memory for shared_exec.aux_array cannot" + " be resued for creating auxilliary NDArray of the argument: " + << arg_name << " for the current executor"; + aux_state_vec->emplace_back(aux_nd); + } else { + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + aux_state_vec->push_back(std::move(it->second.Copy(aux_state_ctxes[aux_top]))); + } else { + aux_state_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + aux_state_ctxes[aux_top], inferred_dtype))); + } + } // if (has_shared_exec) + data_entry_[eid] = aux_state_vec->back(); + aux_state_map_.emplace(arg_name, aux_state_vec->back()); + ++aux_top; + } else { // in_args and grad for in_args + if (shared_arg_names.count(arg_name)) { // model parameter + // model parameter + if (nullptr != shared_exec) { + const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); + auto arg_nd_stype = in_arg_nd.storage_type(); + // for model parameter, both default storage and row_sparse storage can be shared + bool shareable_arg_stype = inferred_stype == kDefaultStorage || + inferred_stype == kRowSparseStorage; + // try to reuse memory from shared_exec + CHECK(shareable_arg_stype) << "Inferred storage type " + << common::stype_string(inferred_stype) + << " does not support memory sharing with shared_exec.arg_array"; + CHECK_EQ(inferred_stype, arg_nd_stype) + << "Inferred stype does not match shared_exec.arg_array's stype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_shape, in_arg_nd.shape()) + << "Inferred shape does not match shared_exec.arg_array's shape" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + CHECK_EQ(inferred_dtype, in_arg_nd.dtype()) + << "Inferred dtype does not match shared_exec.arg_array's dtype" + " Therefore, the allocated memory for shared_exec.arg_array cannot" + " be resued for creating NDArray of the argument " + << arg_name << " for the current executor"; + in_arg_vec->emplace_back(in_arg_nd); + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec + arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); + } else { + // no need to reuse memory from shared_exec for gradient of non-default storage + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } + } else { // !shared_arg_names.count(arg_name) + // model parameter, row_sparse ndarray sharing enabled + bool enable_row_sparse_sharing = true; + + if (use_tensorrt_ && !need_grad_) { +#if MXNET_USE_TENSORRT + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); + } else { + in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + in_arg_ctxes[arg_top], inferred_dtype))); + } +#else + LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't " + << "built with TensorRT. Add USE_TENSORRT = 1 in config.mk"; +#endif + } else { + if (use_tensorrt_) { + LOG(WARNING) << "USE_TENSORRT=1 but grads required. Running without TensorRT"; + } + in_arg_vec->emplace_back(ReshapeOrCreate( + arg_name, inferred_shape, inferred_dtype, inferred_stype, + in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing)); + } + + // gradient for model parameter, row_sparse ndarray sharing disabled + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + bool enable_row_sparse_sharing = false; + arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer, + enable_row_sparse_sharing)); + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } // if (kNullOp == grad_req_types[arg_top]) + } // if (shared_arg_names.count(arg_name)) + in_arg_map_.emplace(arg_name, in_arg_vec->back()); + if (!arg_grad_vec->back().is_none()) { + arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); + } + data_entry_[eid] = in_arg_vec->back(); + ++arg_top; + } + } +} + + + /*! + * \brief This function is triggered after each tensorrt subgraph replacement pass. + * Reset arguments of GraphExecutor::Init(...) as some variables (weights and biases) + * are absorbed into the TRT engine it also it reruns attributes inferences accordingly + * to the new topology. + */ +Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::vector *grad_req_types, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::unordered_map *params_map) { + std::unordered_set to_remove_params; + for (auto& el : *params_map) { + to_remove_params.insert(el.first); + } + + DFSVisit(g.outputs, [&to_remove_params](const nnvm::NodePtr n) { + to_remove_params.erase(n->attrs.name); + }); + + for (auto& el : to_remove_params) { + params_map->erase(el); + arg_shape_map->erase(el); + arg_dtype_map->erase(el); + arg_stype_map->erase(el); + } + const auto &idx = g.indexed_graph(); + num_forward_inputs_ = idx.input_nodes().size(); + in_arg_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + arg_grad_ctxes->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + grad_req_types->resize(num_forward_inputs_ - idx.mutable_input_nodes().size()); + aux_state_ctxes->resize(idx.mutable_input_nodes().size()); + + // create "device" and "context" attrs for the graph + g = AssignContext(g, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, + *aux_state_ctxes, *grad_req_types, num_forward_inputs_, + num_forward_outputs_); + + // get number of nodes used in forward pass + num_forward_nodes_ = 0; + for (size_t i = 0; i < num_forward_outputs_; ++i) { + num_forward_nodes_ = std::max( + num_forward_nodes_, static_cast(idx.outputs()[i].node_id + 1)); + } + nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); + nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + for (size_t i = 0; i < num_forward_inputs_; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const std::string &name = idx[nid].source->attrs.name; + auto it1 = arg_shape_map->find(name); + if (arg_shape_map->end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map->find(name); + if (arg_dtype_map->end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map->find(name); + if (arg_stype_map->end() != it3) { + arg_stypes[i] = it3->second; + } + } + g = InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + if (g.GetAttr("shape_num_unknown_nodes") != 0U) { + HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("shape")); + } + + g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { + HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("dtype")); + } + + g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + + return g; +} + + +/*! + * \brief Return the "optimized" symbol contained in _graph. + * For optimization pass such as TensorRT pass + */ +nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() { + Symbol ret; + ret.outputs = std::vector(graph_.outputs.begin(), + graph_.outputs.begin() + num_forward_outputs_); + ret = ret.Copy(); + static const Op* trt_op = Op::Get("_trt_op"); + DFSVisit(ret.outputs, [](const nnvm::NodePtr n) { + if (n->op() == trt_op) { + n->attrs.dict.clear(); + } + }); + return ret; +} + +} // namespace exec + +Executor *Executor::TensorRTBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + std::vector *in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* shared_buffer, + Executor* shared_exec) { + auto exec = new exec::TrtGraphExecutor(); + exec->Init(symbol, default_ctx, group2ctx, + in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, + arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_types, param_names, + in_args, arg_grads, aux_states, + shared_buffer, shared_exec); + return exec; +} + +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h new file mode 100644 index 000000000000..b201f0623462 --- /dev/null +++ b/src/executor/trt_graph_executor.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ +#define MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ + +#if MXNET_USE_TENSORRT + +#include +#include +#include + +#include "./graph_executor.h" + +namespace mxnet { + +namespace exec { + +class TrtGraphExecutor : public GraphExecutor { + public: + virtual void Init(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set& shared_arg_names, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec, + std::unordered_map* shared_buffer = nullptr, + Executor* shared_exec = nullptr, + const nnvm::NodeEntryMap& feed_dict + = nnvm::NodeEntryMap()); + + nnvm::Symbol GetOptimizedSymbol(); + + protected: + Graph ReinitGraph(Graph&& g, const Context &default_ctx, + const std::map &ctx_map, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::vector *grad_req_types, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::unordered_map *params_map); + + void InitArguments(const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec) override; +}; + +} // namespace exec +} // namespace mxnet + +#endif // MXNET_USE_TENSORRT + +#endif // MXNET_EXECUTOR_TRT_GRAPH_EXECUTOR_H_ diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py new file mode 100644 index 000000000000..751653930b63 --- /dev/null +++ b/tests/python/tensorrt/test_cvnets.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc +import gluoncv +import mxnet as mx +import numpy as np +import os + +from mxnet import gluon +from time import time + +from mxnet.gluon.data.vision import transforms + + +def get_use_tensorrt(): + return int(os.environ.get("MXNET_USE_TENSORRT", 0)) + + +def set_use_tensorrt(status=False): + os.environ["MXNET_USE_TENSORRT"] = str(int(status)) + + +def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, ctx=mx.gpu(0), + batch_size=128): + h, w = 32, 32 + set_use_tensorrt(use_tensorrt) + net = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + + if use_tensorrt: + out = net(data) + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) + executor = mx.contrib.tensorrt.optimize_graph(softmax, ctx=ctx, data=(batch_size, 3, h, w), + softmax_label=(batch_size,), grad_req='null', + shared_buffer=all_params, force_rebind=True) + else: + # Convert gluon model to Symbolic + net.hybridize() + net.forward(mx.ndarray.zeros((batch_size, 3, h, w))) + net.export(model_name) + symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0) + executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), + softmax_label=(batch_size,)) + executor.copy_params_from(arg_params, aux_params) + return executor + + +def cifar10_infer(data_dir='./data', model_name='cifar_resnet56_v1', use_tensorrt=True, + ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): + executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size) + + num_ex = 10000 + all_preds = np.zeros([num_ex, 10]) + + all_label_test = np.zeros(num_ex) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) + ]) + + data_loader = lambda: gluon.data.DataLoader( + gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), + batch_size=batch_size, shuffle=False, num_workers=num_workers) + + val_data = data_loader() + + for idx, (data, label) in enumerate(val_data): + # Skip last batch if it's undersized. + if data.shape[0] < batch_size: + continue + offset = idx * batch_size + all_label_test[offset:offset + batch_size] = label.asnumpy() + + # warm-up, but don't use result + executor.forward(is_train=False, data=data) + executor.outputs[0].wait_to_read() + + gc.collect() + val_data = data_loader() + example_ct = 0 + start = time() + + # if use_tensorrt: + for idx, (data, label) in enumerate(val_data): + # Skip last batch if it's undersized. + if data.shape[0] < batch_size: + continue + executor.forward(is_train=False, data=data) + preds = executor.outputs[0].asnumpy() + offset = idx * batch_size + all_preds[offset:offset + batch_size, :] = preds[:batch_size] + example_ct += batch_size + + all_preds = np.argmax(all_preds, axis=1) + matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum() + duration = time() - start + + return duration, 100.0 * matches / example_ct + + +def run_experiment_for(model_name, batch_size, num_workers): + print("\n===========================================") + print("Model: %s" % model_name) + print("===========================================") + print("*** Running inference using pure MXNet ***\n") + mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, use_tensorrt=False) + print("\nMXNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct)) + print("\n*** Running inference using MXNet + TensorRT ***\n") + trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, + num_workers=num_workers, use_tensorrt=True) + print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct)) + speedup = mx_duration / trt_duration + print("TensorRT speed-up (not counting compilation): %.2fx" % speedup) + + acc_diff = abs(mx_pct - trt_pct) + print("Absolute accuracy difference: %f" % acc_diff) + return speedup, acc_diff + + +def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1): + models = [ + 'cifar_resnet20_v1', + 'cifar_resnet56_v1', + 'cifar_resnet110_v1', + 'cifar_resnet20_v2', + 'cifar_resnet56_v2', + 'cifar_resnet110_v2', + 'cifar_wideresnet16_10', + 'cifar_wideresnet28_10', + 'cifar_wideresnet40_8', + 'cifar_resnext29_16x64d' + ] + + num_models = len(models) + + speedups = np.zeros(num_models, dtype=np.float32) + acc_diffs = np.zeros(num_models, dtype=np.float32) + + test_start = time() + + for idx, model in enumerate(models): + speedup, acc_diff = run_experiment_for(model, batch_size, num_workers) + speedups[idx] = speedup + acc_diffs[idx] = acc_diff + assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % ( + tolerance, model) + + print("Perf and correctness checks run on the following models:") + print(models) + mean_speedup = np.mean(speedups) + std_speedup = np.std(speedups) + print("\nSpeedups:") + print(speedups) + print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) + print("Mean speedup: %.2f" % mean_speedup) + print("St. dev. of speedups: %.2f" % std_speedup) + print("\nAcc. differences: %s" % str(acc_diffs)) + + test_duration = time() - test_start + + print("Test duration: %.2f seconds" % test_duration) + + +if __name__ == '__main__': + import nose + + nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py b/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py deleted file mode 100644 index eda2db0cc8f5..000000000000 --- a/tests/python/tensorrt/test_tensorrt_resnet_resnext_ssd.py +++ /dev/null @@ -1,260 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import gc -import gluoncv -import mxnet as mx -import multiprocessing -import numpy as np -import os -import sys - -from mxnet.gluon.data.vision import transforms -from mxnet import gluon -from time import time - -def get_use_tensorrt(): - return int(os.environ.get("MXNET_USE_TENSORRT", 0)) - -def set_use_tensorrt(status=False): - os.environ["MXNET_USE_TENSORRT"] = str(int(status)) - -def get_fp16_infer_for_fp16_graph(): - return int(os.environ.get("MXNET_TENSORRT_USE_FP16_FOR_FP32", 0)) - -def set_fp16_infer_for_fp16_graph(status=False): - os.environ["MXNET_TENSORRT_USE_FP16_FOR_FP32"] = str(int(status)) - -#ssd_512_resnet50_v1_coco -def get_ssd_model(model_name='ssd_512_mobilenet1_0_coco', use_tensorrt=True, - ctx=mx.gpu(0), batch_size=32, fp16_for_fp32_graph=False): - - set_use_tensorrt(use_tensorrt) - set_fp16_infer_for_fp16_graph(fp16_for_fp32_graph) - net = gluoncv.model_zoo.get_model(model_name, pretrained=True) - data = mx.sym.var('data') - anchors, class_preds, box_preds = net(data) - all_preds = mx.sym.concat(anchors, class_preds, box_preds, dim=2) - all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - - if not get_use_tensorrt(): - all_params = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in all_params.items()]) - - # class_preds - executor = all_preds.simple_bind(ctx=ctx, data=(batch_size, 3, 224, 224), grad_req='null', - shared_buffer=all_params, force_rebind=True) - return executor - - -def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, - ctx=mx.gpu(0), batch_size=128, fp16_for_fp32_graph=False, imagenet=False): - - set_use_tensorrt(use_tensorrt) - set_fp16_infer_for_fp16_graph(fp16_for_fp32_graph) - net = gluoncv.model_zoo.get_model(model_name, pretrained=True) - data = mx.sym.var('data') - out = net(data) - - softmax = mx.sym.SoftmaxOutput(out, name='softmax') - - all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - - if not get_use_tensorrt(): - all_params = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in all_params.items()]) - - if imagenet: - h, w = 224, 224 - else: - h, w = 32, 32 - - executor = softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), grad_req='null', - shared_buffer=all_params, force_rebind=True) - return executor - -def cifar10_infer(data_dir='./data', model_name='cifar_resnet56_v1', use_tensorrt=True, - ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): - - executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph, imagenet=False) - - num_ex = 10000 - all_preds = np.zeros([num_ex, 10]) - - all_label_test = np.zeros(num_ex) - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) - ]) - - data_loader = lambda: gluon.data.DataLoader( - gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), - batch_size=batch_size, shuffle=False, num_workers=num_workers) - - val_data = data_loader() - - for idx, (data, label) in enumerate(val_data): - extent = data.shape[0] - offset = idx*batch_size - all_label_test[offset:offset+extent] = label.asnumpy() - - # warm-up, but don't use result - executor.arg_dict["data"][:extent, :] = data - executor.forward(is_train=False) - executor.outputs[0].wait_to_read() - - gc.collect() - - val_data = data_loader() - example_ct = 0 - - start = time() - - for idx, (data, label) in enumerate(val_data): - extent = data.shape[0] - executor.arg_dict["data"][:extent, :] = data - executor.forward(is_train=False) - preds = executor.outputs[0].asnumpy() - offset = idx*batch_size - all_preds[offset:offset+extent, :] = preds[:extent] - example_ct += extent - - all_preds = np.argmax(all_preds, axis=1) - matches = (all_preds[:example_ct] == all_label_test[:example_ct]).sum() - duration = time() - start - - return duration, 100.0 * matches / example_ct - -def ssd_infer(model_name='ssd_512_mobilenet1_0_voc', use_tensorrt=True, - ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): - - print("Running SSD inference with model: %s" % model_name) - executor = get_ssd_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph) - - start = None - num_runs = 50 - - for i in range(2): - data = np.random.randn(batch_size, 3, 224, 224) - executor.arg_dict["data"] = data - if i == 1: - start = time() - for runs in range(num_runs): - executor.forward(is_train = False) - executor.outputs[0].wait_to_read() -# all_preds = executor.outputs[0].asnumpy() -# anchors = all_preds[:, :, 0] -# class_preds = all_preds[:, :, 1] -# box_preds = all_preds[:, :, 2:] - - return time() - start - -def classif_imagenet_infer(model_name='ssd_512_mobilenet1_0_coco', use_tensorrt=True, - ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): - - executor = get_ssd_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph) - executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size, fp16_for_fp32_graph, imagenet=False) - - start = None - num_runs = 2 - - for i in range(2): - data = np.random.randn(batch_size, 3, 224, 224) - executor.arg_dict["data"] = data - if i == 1: - start = time() - for runs in range(num_runs): - executor.forward(is_train = False) - executor.outputs[0].wait_to_read() - - return time() - start - - -def run_experiment_for(model_name, batch_size, num_workers, fp16_for_fp32_graph): - print("\n===========================================") - print("Model: %s" % model_name) - print("===========================================") - print("*** Running inference using pure MxNet ***\n") - mx_duration, mx_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, - num_workers=num_workers, fp16_for_fp32_graph=fp16_for_fp32_graph, use_tensorrt=False) - print("\nMxNet: time elapsed: %.3fs, accuracy: %.2f%%" % (mx_duration, mx_pct)) - - print("\n*** Running inference using MxNet + TensorRT ***\n") - trt_duration, trt_pct = cifar10_infer(model_name=model_name, batch_size=batch_size, - num_workers=num_workers, use_tensorrt=True) - print("TensorRT: time elapsed: %.3fs, accuracy: %.2f%%" % (trt_duration, trt_pct)) - speedup = mx_duration / trt_duration - print("TensorRT speed-up (not counting compilation): %.2fx" % speedup) - - acc_diff = abs(mx_pct - trt_pct) - print("Absolute accuracy difference: %f" % acc_diff) - return speedup, acc_diff - - -def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1, test_fp16=False): - - models = [ - 'cifar_resnet20_v1', - 'cifar_resnet56_v1', - 'cifar_resnet110_v1', - 'cifar_resnet20_v2', - 'cifar_resnet56_v2', - 'cifar_resnet110_v2', - 'cifar_wideresnet16_10', - 'cifar_wideresnet28_10', - 'cifar_wideresnet40_8', - 'cifar_resnext29_16x64d' - ] - - num_models = len(models) - - speedups = np.zeros(num_models, dtype=np.float32) - acc_diffs = np.zeros(num_models, dtype=np.float32) - - precisions = ["fp32"] - if test_fp16: - precisions.append("fp16") - - for precision in precisions: - - test_start = time() - - print("\n\nRunning inference in %s\n\n" % precision) - use_fp16 = True if precision == "fp16" else False - for idx, model in enumerate(models): - speedup, acc_diff = run_experiment_for(model, batch_size, num_workers, fp16_for_fp32_graph=use_fp16) - speedups[idx] = speedup - acc_diffs[idx] = acc_diff - assert acc_diff < tolerance, "Accuracy difference between MxNet and TensorRT > %.2f%% for model %s" % (tolerance, model) - - print("Perf and correctness checks run on the following models:") - print(models) - mean_speedup = np.mean(speedups) - std_speedup = np.std(speedups) - print("\nSpeedups:") - print(speedups) - print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) - print("Mean speedup: %.2f" % mean_speedup) - print("St. dev. of speedups: %.2f" % std_speedup) - print("\nAcc. differences: %s" % str(acc_diffs)) - - test_duration = time() - test_start - - print("Test duration: %.2f seconds" % test_duration) - -if __name__ == '__main__': - import nose - nose.runmodule() From 87ebdce1bff7d19cb449c6403972b19037c15743 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 15:58:43 +0200 Subject: [PATCH 10/43] Add more build guards, remove unused code --- include/mxnet/c_api.h | 7 ++++--- src/c_api/c_api_executor.cc | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 319e6d8f255c..1254e335979f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1445,9 +1445,6 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, const int **aux_type_data, int *complete); -MXNET_DLL int MXTRTSymbolOptimize(SymbolHandle sym_handle, - SymbolHandle *ret_sym_handle); - /*! * \brief Convert a symbol into a quantized symbol where FP32 operators are replaced with INT8 * \param sym_handle symbol to be converted @@ -1677,6 +1674,8 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, ExecutorHandle shared_exec_handle, ExecutorHandle* out); +#if MXNET_USE_TENSORRT + MXNET_DLL int MXExecutorTensorRTBind(SymbolHandle symbol_handle, int dev_type, int dev_id, @@ -1712,6 +1711,8 @@ MXNET_DLL int MXExecutorTensorRTBind(SymbolHandle symbol_handle, ExecutorHandle shared_exec_handle, ExecutorHandle* out); +#endif // MXNET_USE_TENSORRT + /*! * \brief Return a new executor with the same symbol and shared memory, * but different input/output shapes. diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 30f692e63d8a..ac775bcf08e0 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -27,7 +27,9 @@ #include #include "./c_api_common.h" #include "../executor/graph_executor.h" +#if MXNET_USE_TENSORRT #include "../executor/trt_graph_executor.h" +#endif // MXNET_USE_TENSORRT int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { Executor *exec = static_cast(handle); From 83fa475037e096d6d00d03512ff350503dbad0e5 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 17:18:33 +0200 Subject: [PATCH 11/43] Remove ccache report --- ci/docker/runtime_functions.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 6e1bd5b0449c..3e19eaf70049 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -488,8 +488,6 @@ build_ubuntu_gpu_tensorrt() { ONNX_NAMESPACE=onnx \ CUDA_ARCH="-gencode arch=compute_70,code=compute_70"\ -j$(nproc) - - report_ccache_usage } build_ubuntu_gpu_mkldnn() { From 4a3772f70d7ba6f824437fe2fb4922140eca4c37 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 17:19:21 +0200 Subject: [PATCH 12/43] Remove redundant const in declaration --- include/mxnet/c_api.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 1254e335979f..a004fb4e01e6 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1679,24 +1679,24 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, MXNET_DLL int MXExecutorTensorRTBind(SymbolHandle symbol_handle, int dev_type, int dev_id, - const mx_uint num_g2c_keys, + mx_uint num_g2c_keys, const char** g2c_keys, const int* g2c_dev_types, const int* g2c_dev_ids, - const mx_uint provided_grad_req_list_len, + mx_uint provided_grad_req_list_len, const char** provided_grad_req_names, const char** provided_grad_req_types, - const mx_uint num_provided_arg_shapes, + mx_uint num_provided_arg_shapes, const char** provided_arg_shape_names, const mx_uint* provided_arg_shape_data, const mx_uint* provided_arg_shape_idx, - const mx_uint num_provided_arg_dtypes, + mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, - const mx_uint num_provided_arg_stypes, + mx_uint num_provided_arg_stypes, const char** provided_arg_stype_names, const int* provided_arg_stypes, - const mx_uint num_shared_arg_names, + mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, const char** shared_buffer_name_list, From f779537fd1d8872c4b4c982b1f409c06562f7045 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 17:40:34 +0200 Subject: [PATCH 13/43] Clean Cmake TRT files --- CMakeLists.txt | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 95e9f52890ad..8ff337ed159a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,39 +196,24 @@ if(USE_TENSORRT) include_directories(3rdparty/) add_definitions(-DMXNET_USE_TENSORRT=1) add_definitions(-DONNX_NAMESPACE=onnx) - list(APPEND mxnet_LINKER_LIBS libnvinfer.so) find_package(Protobuf REQUIRED) - list(APPEND mxnet_LINKER_LIBS ${PROTOBUF_LIBRARY}) - message(STATUS "target path ${ONNX_PATH}") find_library(ONNX_LIBRARY NAMES libonnx.so REQUIRED PATHS ${ONNX_PATH} DOC "Path to onnx library.") - message(STATUS "linking onnx ${ONNX_LIBRARY}") - list(APPEND mxnet_LINKER_LIBS ${ONNX_LIBRARY}) - - message(STATUS "target path ${ONNX_PATH}") find_library(ONNX_PROTO_LIBRARY NAMES libonnx_proto.so REQUIRED PATHS ${ONNX_PATH} DOC "Path to onnx_proto library.") - message(STATUS "linking proto onnx ${ONNX_PROTO_LIBRARY}") - list(APPEND mxnet_LINKER_LIBS ${ONNX_PROTO_LIBRARY}) - - message(STATUS "target path ${ONNX_TRT_PATH}") find_library(ONNX_TRT_RUNTIME_LIBRARY NAMES libnvonnxparser_runtime.so REQUIRED PATHS ${ONNX_TRT_PATH} DOC "Path to onnx_proto library.") - message(STATUS "linking proto onnx ${ONNX_TRT_RUNTIME_LIBRARY}") - list(APPEND mxnet_LINKER_LIBS ${ONNX_TRT_RUNTIME_LIBRARY}) - - message(STATUS "target path ${ONNX_TRT_PATH}") find_library(ONNX_TRT_PARSER_LIBRARY NAMES libnvonnxparser.so REQUIRED PATHS ${ONNX_TRT_PATH} DOC "Path to onnx_proto library.") - message(STATUS "linking proto onnx ${ONNX_TRT_PARSER_LIBRARY}") - list(APPEND mxnet_LINKER_LIBS ${ONNX_TRT_PARSER_LIBRARY}) + list(APPEND mxnet_LINKER_LIBS libnvinfer.so ${ONNX_TRT_PARSER_LIBRARY} ${ONNX_TRT_RUNTIME_LIBRARY} + ${ONNX_PROTO_LIBRARY} ${ONNX_LIBRARY} ${PROTOBUF_LIBRARY}) endif() if(USE_MKLDNN) From b0748ef2c43fc6472431e0428710a16f443736b7 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 17:52:15 +0200 Subject: [PATCH 14/43] Remove TensorRT env var usage We don't want to use environment variables with TensorRT yet, the logic being that we want to try and have as much fwd compatiblity as possible when working on an experimental feature. Were we to add env vars they would have to be gaurenteed to work in the future until a major version change. Moving the functionality to a contrib call reduces this risk. --- tests/python/tensorrt/common.py | 8 -------- tests/python/tensorrt/test_cvnets.py | 11 ----------- 2 files changed, 19 deletions(-) diff --git a/tests/python/tensorrt/common.py b/tests/python/tensorrt/common.py index c64649a19c41..eb599f69973c 100644 --- a/tests/python/tensorrt/common.py +++ b/tests/python/tensorrt/common.py @@ -23,14 +23,6 @@ def check_tensorrt_installation(): assert find_library('nvinfer') is not None, "Can't find the TensorRT shared library" -def get_use_tensorrt(): - return int(os.environ.get("MXNET_USE_TENSORRT", 0)) - - -def set_use_tensorrt(status=False): - os.environ["MXNET_USE_TENSORRT"] = str(int(status)) - - def merge_dicts(*dict_args): """Merge arg_params and aux_params to populate shared_buffer""" result = {} diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py index 751653930b63..451983b33e13 100644 --- a/tests/python/tensorrt/test_cvnets.py +++ b/tests/python/tensorrt/test_cvnets.py @@ -19,26 +19,15 @@ import gluoncv import mxnet as mx import numpy as np -import os from mxnet import gluon from time import time from mxnet.gluon.data.vision import transforms - -def get_use_tensorrt(): - return int(os.environ.get("MXNET_USE_TENSORRT", 0)) - - -def set_use_tensorrt(status=False): - os.environ["MXNET_USE_TENSORRT"] = str(int(status)) - - def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, ctx=mx.gpu(0), batch_size=128): h, w = 32, 32 - set_use_tensorrt(use_tensorrt) net = gluoncv.model_zoo.get_model(model_name, pretrained=True) data = mx.sym.var('data') From 35e1367fdf6ff24687a7d034c6a6d9ad6befc5de Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 18:00:15 +0200 Subject: [PATCH 15/43] Use contrib optimize_graph instaed of bind --- tests/python/tensorrt/test_cycle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py index 37c3f5da2689..adc5943e34d4 100644 --- a/tests/python/tensorrt/test_cycle.py +++ b/tests/python/tensorrt/test_cycle.py @@ -54,8 +54,8 @@ def test_simple_cycle(): 'B_weight': mx.nd.zeros([10,10]), 'B_bias': mx.nd.zeros([10]), } - set_use_tensorrt(True) - executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), + + executor = mx.contrib.tensorrt.optimize_graph(C, ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), shared_buffer=arg_params, grad_req='null', force_rebind=True) assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle" From 6338e4598c07594fd976cf6f00b97d48fc9e3d54 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 18:14:46 +0200 Subject: [PATCH 16/43] Clean up cycle detector --- tests/python/tensorrt/test_cycle.py | 56 +++++++++++++++-------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py index adc5943e34d4..a3319ea02f25 100644 --- a/tests/python/tensorrt/test_cycle.py +++ b/tests/python/tensorrt/test_cycle.py @@ -18,35 +18,38 @@ import mxnet as mx from common import * + def detect_cycle_from(sym, visited, stack): - visited.add(sym.handle.value) - stack.add(sym.handle.value) - for s in s.get_children(): - if s.handle.value not in visited: - if detect_cycle_from(sym, visited, stack): - return True - elif s.handle.value in stack: - return True - stack.remove(sym.handle.value) - return False + visited.add(sym.handle.value) + stack.add(sym.handle.value) + for s in sym.get_children(): + if s.handle.value not in visited: + if detect_cycle_from(sym, visited, stack): + return True + elif s.handle.value in stack: + return True + stack.remove(sym.handle.value) + return False + def has_no_cycle(sym): - visited = set() - stack = set() - all_nodes = sym.get_internals() - for s in all_nodes: - if s.handle.value in visited: - if detect_cycle_from(s, visited, stack): - return False - return True + visited = set() + stack = set() + all_nodes = sym.get_internals() + for s in all_nodes: + if s.handle.value in visited: + if detect_cycle_from(s, visited, stack): + return False + return True + def test_simple_cycle(): - inp = mx.sym.Variable('input', shape=[1,10]) - A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A') - B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B') - D = mx.sym.sin(data=A, name='D') - C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C') - arg_params = { + inp = mx.sym.Variable('input', shape=[1,10]) + A = mx.sym.FullyConnected(data=inp, num_hidden=10, no_bias=False, name='A') + B = mx.sym.FullyConnected(data=A, num_hidden=10, no_bias=False, name='B') + D = mx.sym.sin(data=A, name='D') + C = mx.sym.elemwise_add(lhs=B, rhs=D, name='C') + arg_params = { 'I_weight': mx.nd.zeros([10,10]), 'I_bias': mx.nd.zeros([10]), 'A_weight': mx.nd.zeros([10,10]), @@ -55,9 +58,10 @@ def test_simple_cycle(): 'B_bias': mx.nd.zeros([10]), } - executor = mx.contrib.tensorrt.optimize_graph(C, ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), + executor = mx.contrib.tensorrt.optimize_graph(C, ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), shared_buffer=arg_params, grad_req='null', force_rebind=True) - assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle" + assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle" + if __name__ == '__main__': import nose From 21d02393ac49ca0bedb5428b0784357a59f973a1 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 18:18:21 +0200 Subject: [PATCH 17/43] Convert lenet test to contrib optimize --- tests/python/tensorrt/test_tensorrt_lenet5.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index 396b7dad4ab8..e497faeabc7a 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -21,18 +21,25 @@ from common import * from lenet5_common import get_iters -def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size): +def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt): """Run inference with either MXNet or TensorRT""" shared_buffer = merge_dicts(arg_params, aux_params) - if not get_use_tensorrt(): + if not use_tensorrt: shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) - executor = sym.simple_bind(ctx=mx.gpu(0), + executor = sym.simple_bind(ctx=mx.gpu(0), data=(batch_size,) + mnist['test_data'].shape[1:], softmax_label=(batch_size,), shared_buffer=shared_buffer, grad_req='null', force_rebind=True) + else: + executor = mx.contrib.tensorrt.optimize_graph(sym, ctx=mx.gpu(0), + data=(batch_size,) + mnist['test_data'].shape[1:], + softmax_label=(batch_size,), + shared_buffer=shared_buffer, + grad_req='null', + force_rebind=True) # Get this value from all_test_labels # Also get classes from the dataset @@ -75,12 +82,10 @@ def test_tensorrt_inference(): print("LeNet-5 test") print("Running inference in MXNet") - set_use_tensorrt(False) mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size=batch_size) print("Running inference in MXNet-TensorRT") - set_use_tensorrt(True) trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size=batch_size) From f30dbefed260569542c15c04026e3035bb4d4479 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 18:33:01 +0200 Subject: [PATCH 18/43] Protect interface with trt build flag --- include/mxnet/executor.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 5b074cedd67b..02a22b15a77f 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -167,6 +167,8 @@ class Executor { shared_data_arrays = nullptr, Executor* shared_exec = nullptr); +#if MXNET_USE_TENSORRT + static Executor* TensorRTBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, @@ -185,6 +187,8 @@ class Executor { shared_data_arrays = nullptr, Executor* shared_exec = nullptr); +#endif // MXNET_USE_TENSORRT + /*! * \brief the prototype of user-defined monitor callback */ From eaba5932918ae2edb463dc042e67aed160767193 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 19:04:28 +0200 Subject: [PATCH 19/43] Fix whitespace issues --- src/executor/graph_executor.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 070ab609d7b4..13813eb5b8d2 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -87,7 +87,6 @@ class GraphExecutor : public Executor { Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict = nnvm::NodeEntryMap()); - // initialize executor for simple bind void Init(nnvm::Symbol symbol, const Context& default_ctx, @@ -106,7 +105,7 @@ class GraphExecutor : public Executor { std::unordered_map* shared_buffer = nullptr, Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict - = nnvm::NodeEntryMap()); + = nnvm::NodeEntryMap()); Executor* Reshape(const bool partial_shaping, const bool allow_up_sizing, @@ -162,7 +161,6 @@ class GraphExecutor : public Executor { std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec); - // Initialize in_args, arg_grads and aux_states with // shared_buffer and shared_exec virtual void InitArguments(const nnvm::IndexedGraph& idx, @@ -215,6 +213,7 @@ class GraphExecutor : public Executor { void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); + static void HandleInferShapeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes); From f870a3fca0f443e1ba32ee5ea1d401671e1120a2 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 19:35:13 +0200 Subject: [PATCH 20/43] Add another build guard to c_api --- include/mxnet/c_api.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a004fb4e01e6..84216d4e5cbe 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1754,12 +1754,17 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping, ExecutorHandle shared_exec, ExecutorHandle *out); + +#if MXNET_USE_TENSORRT + /*! * \brief get optimized graph from graph executor */ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out); +#endif // MXNET_USE_TENSORRT + /*! * \brief set a call back to notify the completion of operation */ From e40d6b31a851b1b48ee7af0f8e8fedf6c29210f7 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Wed, 8 Aug 2018 19:35:30 +0200 Subject: [PATCH 21/43] Move get_optimized_symbol to contrib area --- python/mxnet/contrib/tensorrt.py | 22 +++++++++++++++++++++- python/mxnet/executor.py | 15 --------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 1cfec707d183..7f60859d69d7 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -20,7 +20,7 @@ import ctypes -from ..base import _LIB, c_array, c_array_buf, c_str_array, c_handle_array +from ..base import _LIB, c_array, c_array_buf, c_str_array, c_handle_array, SymbolHandle from ..base import mx_uint, py_str, string_types from ..base import NDArrayHandle, ExecutorHandle from ..base import check_call, MXNetError @@ -31,6 +31,26 @@ import numpy as _numpy +def get_optimized_symbol(sym): + """Get optimized symbol. + + Parameters + ---------- + sym : nnvm::Symbol + The original symbol for which you wish to view the optimized form. + + Returns + ------- + symbol : nnvm::Symbol + The nnvm symbol optimized. + """ + if sym._optimized_symbol is None: + handle = SymbolHandle() + check_call(_LIB.MXExecutorGetOptimizedSymbol(sym.handle, ctypes.byref(handle))) + sym._optimized_symbol = sym.Symbol(handle=sym.handle) + return sym._optimized_symbol + + def optimize_graph(sym, ctx, grad_req='write', type_dict=None, stype_dict=None, group2ctx=None, shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index b71127ef0cf8..0d8a7835bd90 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -325,21 +325,6 @@ def output_dict(self): self._symbol.list_outputs(), self.outputs) return self._output_dict - @property - def optimized_symbol(self): - """Get optimized symbol. - - Returns - ------- - symbol : nnvm::Symbol - The nnvm symbol optimized. - """ - if self._optimized_symbol is None: - handle = SymbolHandle() - check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(handle))) - self._optimized_symbol = mx.sym.Symbol(handle=handle) - return self._optimized_symbol - def copy_params_from(self, arg_params, aux_params=None, allow_extra_params=False): """Copy parameters from arg_params, aux_params into executor's internal array. From 7fa6a4adc5116ab64755cbe128619cc684f9966c Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 12:14:12 +0200 Subject: [PATCH 22/43] Ignore gz files in test folder --- tests/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/.gitignore b/tests/.gitignore index d6459089c245..6100add5e450 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1 +1,2 @@ *_unittest +*.gz \ No newline at end of file From e777ab564dafbbf5a3bae18cb2aa7fb1eba32e46 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 12:14:37 +0200 Subject: [PATCH 23/43] Make trt optimization implicit --- Jenkinsfile | 20 +- include/mxnet/c_api.h | 39 -- include/mxnet/executor.h | 2 +- python/mxnet/contrib/tensorrt.py | 248 ++---------- python/mxnet/executor.py | 3 +- src/c_api/c_api_executor.cc | 368 ++---------------- src/executor/graph_executor.cc | 18 +- src/executor/graph_executor.h | 33 +- src/executor/trt_graph_executor.cc | 37 +- src/executor/trt_graph_executor.h | 3 +- tests/.gitignore | 2 +- tests/python/tensorrt/test_cvnets.py | 102 ++--- tests/python/tensorrt/test_cycle.py | 5 +- tests/python/tensorrt/test_tensorrt_lenet5.py | 75 ++-- 14 files changed, 209 insertions(+), 746 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 32b89561aed9..758e8e870eee 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -374,12 +374,12 @@ try { } }, 'TensorRT': { - node('mxnetlinux-cpu') { + node(NODE_LINUX_CPU) { ws('workspace/build-tensorrt') { timeout(time: max_time, unit: 'MINUTES') { - init_git() - docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false) - pack_lib('tensorrt', mx_tensorrt_lib) + utils.init_git() + utils.docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false) + utils.pack_lib('tensorrt', mx_tensorrt_lib) } } } @@ -753,16 +753,16 @@ try { } }, 'Python3: TensorRT GPU': { - node('mxnetlinux-gpu-p3') { + node(NODE_LINUX_GPU_P3) { ws('workspace/build-tensorrt') { timeout(time: max_time, unit: 'MINUTES') { try { - init_git() - unpack_lib('tensorrt', mx_tensorrt_lib) - docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) - publish_test_coverage() + utils.init_git() + utils.unpack_lib('tensorrt', mx_tensorrt_lib) + utils.docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true) + utils.publish_test_coverage() } finally { - collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml') + utils.collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml') } } } diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 84216d4e5cbe..de9918a4cf96 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1674,45 +1674,6 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, ExecutorHandle shared_exec_handle, ExecutorHandle* out); -#if MXNET_USE_TENSORRT - -MXNET_DLL int MXExecutorTensorRTBind(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - mx_uint num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - mx_uint provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - mx_uint num_provided_arg_shapes, - const char** provided_arg_shape_names, - const mx_uint* provided_arg_shape_data, - const mx_uint* provided_arg_shape_idx, - mx_uint num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - mx_uint num_provided_arg_stypes, - const char** provided_arg_stype_names, - const int* provided_arg_stypes, - mx_uint num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - mx_uint* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - mx_uint* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out); - -#endif // MXNET_USE_TENSORRT - /*! * \brief Return a new executor with the same symbol and shared memory, * but different input/output shapes. diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 02a22b15a77f..412e12b01e20 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -164,7 +164,7 @@ class Executor { std::vector* arg_grads, std::vector* aux_states, std::unordered_map* - shared_data_arrays = nullptr, + shared_data_arrays = nullptr, Executor* shared_exec = nullptr); #if MXNET_USE_TENSORRT diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 7f60859d69d7..7720f717d2c5 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -16,231 +16,57 @@ # under the License. """ Module to enable the use of TensorRT optimized graphs.""" -# pylint: skip-file import ctypes +import logging +import os -from ..base import _LIB, c_array, c_array_buf, c_str_array, c_handle_array, SymbolHandle -from ..base import mx_uint, py_str, string_types -from ..base import NDArrayHandle, ExecutorHandle -from ..base import check_call, MXNetError -from array import array -from ..ndarray import _ndarray_cls -from ..executor import Executor +from mxnet.symbol import Symbol -import numpy as _numpy +from ..base import _LIB, SymbolHandle, MXNetError +from ..base import check_call -def get_optimized_symbol(sym): +def set_use_tensorrt(status=False): + """ + Set an environment variable which will enable or disable the use of TensorRT in the backend. + Note: this is useful for A/B testing purposes. + :param status: Boolean, true if TensorRT optimization should be applied, False for legacy + behaviour. + """ + os.environ["MXNET_USE_TENSORRT"] = str(int(status)) + + +def get_use_tensorrt(): + """ + Get an environment variable which describes if TensorRT is currently enabled in the backend. + Note: this is useful for A/B testing purposes. + :return: Boolean, true if TensorRT optimization should be applied, False for legacy + behaviour. + """ + return bool(int(os.environ.get("MXNET_USE_TENSORRT", 1)) == 1) + + +def get_optimized_symbol(executor): """Get optimized symbol. Parameters ---------- - sym : nnvm::Symbol - The original symbol for which you wish to view the optimized form. + executor : + An executor for which you want to see an optimized symbol. Getting an optimized symbol + is useful to compare and verify the work TensorRT has done against a legacy behaviour. Returns ------- symbol : nnvm::Symbol The nnvm symbol optimized. """ - if sym._optimized_symbol is None: - handle = SymbolHandle() - check_call(_LIB.MXExecutorGetOptimizedSymbol(sym.handle, ctypes.byref(handle))) - sym._optimized_symbol = sym.Symbol(handle=sym.handle) - return sym._optimized_symbol - - -def optimize_graph(sym, ctx, grad_req='write', type_dict=None, stype_dict=None, - group2ctx=None, shared_arg_names=None, shared_exec=None, - shared_buffer=None, **kwargs): - num_provided_arg_types = 0 - provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names - provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types - if type_dict is not None: - provided_arg_type_names = [] - provided_arg_type_data = [] - for k, v in type_dict.items(): - v = _numpy.dtype(v).type - if v in _DTYPE_NP_TO_MX: - provided_arg_type_names.append(k) - provided_arg_type_data.append(_DTYPE_NP_TO_MX[v]) - num_provided_arg_types = mx_uint(len(provided_arg_type_names)) - provided_arg_type_names = c_str_array(provided_arg_type_names) - provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data)) - - # storage types - num_provided_arg_stypes = 0 - # provided storage type argument names - provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() - provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types - if stype_dict is not None: - provided_arg_stype_names = [] - provided_arg_stype_data = [] - for k, v in stype_dict.items(): - if v in _STORAGE_TYPE_STR_TO_ID: - provided_arg_stype_names.append(k) - provided_arg_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v]) - num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) - provided_arg_stype_names = c_str_array(provided_arg_stype_names) - provided_arg_stype_data = c_array_buf(ctypes.c_int, array('i', provided_arg_stype_data)) - - provided_arg_shape_data = [] # shape data - # argument shape index in sdata, - # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg - provided_arg_shape_idx = [0] - provided_arg_shape_names = [] # provided argument names - for k, v in kwargs.items(): - # if k not in listed_arguments and k not in listed_aux_states: - # raise ValueError('arg name %s is not valid', k) - if isinstance(v, tuple): - provided_arg_shape_names.append(k) - provided_arg_shape_data.extend(v) - provided_arg_shape_idx.append(len(provided_arg_shape_data)) - - provided_req_type_list_len = 0 - provided_grad_req_types = ctypes.POINTER(ctypes.c_char_p)() - provided_grad_req_names = ctypes.POINTER(ctypes.c_char_p)() - if grad_req is not None: - if isinstance(grad_req, string_types): - # use provided_req_type_list_len = 0 to indicate this situation - provided_req_type_list_len = 0 - provided_grad_req_types = [grad_req] - elif isinstance(grad_req, list): - if len(grad_req) == 0: - raise RuntimeError('grad_req in simple_bind cannot be an empty list') - provided_grad_req_types = grad_req - provided_req_type_list_len = len(provided_grad_req_types) - elif isinstance(grad_req, dict): - if len(grad_req) == 0: - raise RuntimeError('grad_req in simple_bind cannot be an empty dict') - provided_grad_req_names = [] - provided_grad_req_types = [] - for k, v in grad_req.items(): - provided_grad_req_names.append(k) - provided_grad_req_types.append(v) - provided_grad_req_names = c_str_array(provided_grad_req_names) - provided_req_type_list_len = len(provided_grad_req_types) - provided_grad_req_types = c_str_array(provided_grad_req_types) - - num_ctx_map_keys = mx_uint(0) - ctx_map_keys = ctypes.POINTER(ctypes.c_char_p)() - ctx_map_dev_types = ctypes.POINTER(ctypes.c_int)() - ctx_map_dev_ids = ctypes.POINTER(ctypes.c_int)() - if group2ctx is not None: - ctx_map_keys = [] - ctx_map_dev_types = [] - ctx_map_dev_ids = [] - for key, val in group2ctx.items(): - ctx_map_keys.append(key) - ctx_map_dev_types.append(val.device_typeid) - ctx_map_dev_ids.append(val.device_id) - num_ctx_map_keys = mx_uint(len(ctx_map_keys)) - ctx_map_keys = c_str_array(ctx_map_keys) - ctx_map_dev_types = c_array(ctypes.c_int, array('i', ctx_map_dev_types)) - ctx_map_dev_ids = c_array(ctypes.c_int, array('i', ctx_map_dev_ids)) - - # prepare param names - shared_arg_name_list = [] - if shared_arg_names is not None: - if not isinstance(shared_arg_names, list): - raise ValueError('shared_arg_names in simple_bind must be a list or None') - shared_arg_name_list = shared_arg_names - - # prepare shared_buffer - if shared_buffer is None: - shared_buffer_len = ctypes.c_int(-1) - shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() - shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() - else: - if not isinstance(shared_buffer, dict): - raise ValueError('shared_buffer in simple_bind must be dict or None') - buffer_names = shared_buffer.keys() - buffer_arrays = shared_buffer.values() - for v in buffer_arrays: - assert(v.stype == 'default'), \ - "shared_buffer is expected to only contain NDArrays with default storage" - shared_buffer_names = c_str_array(buffer_names) - shared_buffer_len = ctypes.c_int(len(buffer_arrays)) - shared_buffer_handles = c_handle_array(buffer_arrays) - updated_shared_buffer_names = ctypes.POINTER(ctypes.c_char_p)() - updated_shared_buffer_handles = ctypes.POINTER(NDArrayHandle)() - - # prepare shared_exec_handle - shared_exec_handle = shared_exec.handle if shared_exec is not None else ExecutorHandle() - - # prepare current executor handle - exe_handle = ExecutorHandle() - - # prepare current executor's in_args, arg_grads, and aux_states - num_in_args = ctypes.c_uint() - in_arg_handles = ctypes.POINTER(NDArrayHandle)() - arg_grad_handles = ctypes.POINTER(NDArrayHandle)() - num_aux_states = ctypes.c_uint() - aux_state_handles = ctypes.POINTER(NDArrayHandle)() - + handle = SymbolHandle() try: - check_call(_LIB.MXExecutorTensorRTBind(sym.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_uint, - array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('I', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) - except MXNetError as e: - error_msg = "simple_bind error. Arguments:\n" - for k, v in kwargs.items(): - error_msg += "%s: %s\n" % (k, v) - error_msg += "%s" % e - raise RuntimeError(error_msg) - - # update shared_buffer - if shared_buffer is not None: - for i in range(shared_buffer_len.value): - k = py_str(updated_shared_buffer_names[i]) - v = NDArray(NDArrayHandle(updated_shared_buffer_handles[i])) - shared_buffer[k] = v - - # create in_args, arg_grads, and aux_states for the current executor - arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) - for i in range(num_in_args.value)] - grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) - if arg_grad_handles[i] is not None - else None for i in range(num_in_args.value)] - aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) - for i in range(num_aux_states.value)] - - executor = Executor(exe_handle, sym, ctx, grad_req, group2ctx) - executor.arg_arrays = arg_arrays - executor.grad_arrays = grad_arrays - executor.aux_arrays = aux_arrays - return executor \ No newline at end of file + check_call(_LIB.MXExecutorGetOptimizedSymbol(executor.handle, ctypes.byref(handle))) + result = Symbol(handle=handle) + return result + except MXNetError: + logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure' + 'build was compiled with MXNET_USE_TENSORRT enabled.') + raise diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 0d8a7835bd90..fcd5406236e9 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -24,9 +24,8 @@ import ctypes import copy import numpy as np -import mxnet as mx from .base import _LIB -from .base import mx_uint, NDArrayHandle, ExecutorHandle, SymbolHandle, py_str +from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str from .base import check_call, c_handle_array, c_array_buf, c_str_array from .ndarray import NDArray from .ndarray import _ndarray_cls diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ac775bcf08e0..64390c46b067 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -443,13 +443,28 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector in_arg_vec; std::vector arg_grad_vec; std::vector aux_state_vec; - - *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, - grad_req_type_vec, shared_arg_name_set, &in_arg_vec, - &arg_grad_vec, &aux_state_vec, - use_shared_buffer ? &shared_buffer_map : nullptr, - reinterpret_cast(shared_exec_handle)); +#if MXNET_USE_TENSORRT + // If we've built with TensorRT support we by default return an TRTExecutor. + // Users can override this behaviour via env var, which is useful for example for A/B + // performance testing. + if (dmlc::GetEnv("MXNET_USE_TENSORRT", true)) { + *out = Executor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, + &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, + &arg_stype_map, &grad_req_type_vec, shared_arg_name_set, + &in_arg_vec, &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); + } else { +#endif // MXNET_USE_TENSORRT + *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); +#if MXNET_USE_TENSORRT + } +#endif // MXNET_USE_TENSORRT // copy ndarray ptrs to ret->handles so that front end // can access them @@ -514,345 +529,6 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, API_END(); } -#if MXNET_USE_TENSORRT - -/*! - * \brief - * \param symbol_handle symbol handle - * \param dev_type default device type - * \param dev_id default device id - * \param num_g2c_keys number of group2ctx keys - * \param g2c_keys key list of group2ctx - * \param g2c_dev_types device type list of group2ctx - * \param g2c_dev_ids id list of group2ctx - * \param provided_grad_req_list_len grad_req length provided by users in front-end - * \param provided_grad_req_names grad_req names provided by users in front-end - * \param provided_grad_req_types req types provided by users in front-end - * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes - * \param provided_arg_shape_names name list of provided shapes - * \param provided_arg_shape_data provided shape data - * \param provided_arg_shape_idx provided shape data index - * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes - * \param provided_arg_dtype_names argument name list of provided dtypes - * \param provided_arg_dtypes data of provided dtypes - * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types - * \param provided_arg_stype_names argument name list of provided storage types - * \param provided_arg_stypes data of provided storage types - * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec - * \param shared_arg_name_list parameter name list passed from _bind_ith_exec - * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec - * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec - * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec - * \param updated_shared_buffer_name_list updated shared data array names after binding - * \param updated_shared_buffer_handle_list updated shared data arrays after binding - * \param num_in_args number of input arguments of this sym - * \param in_args list_arguments associated with the current executor - * \param arg_grads list of gradients of in_args associated with the current executor - * \param num_aux_states number of aux states of this sym - * \param aux_states list_auxiliary_states associated with the current executor - * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec - * \param out the handle of the executor to be created - */ -int MXExecutorTensorRTBind(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const mx_uint num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - const mx_uint provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - const mx_uint num_provided_arg_shapes, - const char** provided_arg_shape_names, - const mx_uint* provided_arg_shape_data, - const mx_uint* provided_arg_shape_idx, - const mx_uint num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - const mx_uint num_provided_arg_stypes, - const char** provided_arg_stype_names, - const int* provided_arg_stypes, - const mx_uint num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - mx_uint* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - mx_uint* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); - API_BEGIN(); - nnvm::Symbol *sym = static_cast(symbol_handle); - - // get in_arg names - std::vector in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); - std::vector aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); - - // attr_dict for setting up type_dict and arg/aux ctx - std::unordered_map> attr_dict; - if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) { - std::vector> attrs = - sym->ListAttrsRecursive(); - attr_dict.reserve(attrs.size()); - for (const auto& tp : attrs) { - attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp); - } - } - - // setup arg_dtype_map - std::unordered_map arg_dtype_map; - if (nullptr == provided_arg_dtypes) { // use attr_dict - for (const auto& arg_name : in_arg_names) { - const auto it = attr_dict.find(arg_name); - if (it == attr_dict.end() || !it->second.count("__dtype__")) { - arg_dtype_map[arg_name] = mshadow::kFloat32; - } - } - } else { // use user input type_dict - // create dtype map for in_args and aux_states - arg_dtype_map.reserve(num_provided_arg_dtypes); - for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { - arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; - } - } - - // setup arg_stype_map - std::unordered_map arg_stype_map; - if (nullptr == provided_arg_stypes) { // use attr_dict - for (const auto& arg_name : in_arg_names) { - const auto it = attr_dict.find(arg_name); - if (it == attr_dict.end() || !it->second.count("__storage_type__")) { - arg_stype_map[arg_name] = kDefaultStorage; - } - } - } else { // use user input type_dict - // create stype map for in_args and aux_states - arg_stype_map.reserve(num_provided_arg_stypes); - for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { - arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; - } - } - - // create default ctx - Context ctx = Context::Create(static_cast(dev_type), dev_id); - // create ctx map - std::map ctx_map; - std::vector in_arg_ctx_vec(in_arg_names.size(), ctx); - std::vector aux_state_ctx_vec(aux_state_names.size(), ctx); - if (nullptr != g2c_keys) { // use user input group2ctx dict - for (mx_uint i = 0; i < num_g2c_keys; ++i) { - ctx_map[g2c_keys[i]] = Context::Create( - static_cast(g2c_dev_types[i]), g2c_dev_ids[i]); - } - - // initialize in_arg_ctx_vec using group2ctx if there are any - for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) { - const auto it1 = attr_dict.find(in_arg_names[i]); - if (it1 != attr_dict.end()) { - const auto it2 = it1->second.find("__ctx_group__"); - if (it2 != it1->second.end()) { - const auto it3 = ctx_map.find(it2->second); - if (it3 != ctx_map.end()) { - in_arg_ctx_vec[i] = it3->second; - } - } - } - } - - // initialize aux_state_ctx_vec using group2ctx if there are any - for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) { - const auto it1 = attr_dict.find(aux_state_names[i]); - if (it1 != attr_dict.end()) { - const auto it2 = it1->second.find("__ctx_group__"); - if (it2 != it1->second.end()) { - const auto it3 = ctx_map.find(it2->second); - if (it3 != ctx_map.end()) { - aux_state_ctx_vec[i] = it3->second; - } - } - } - } - } - - // create provided_grad_req_map - const std::map req_map = - {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}}; - std::unordered_map provided_grad_req_map; - std::string grad_req_type; - if (0 == provided_grad_req_list_len - && nullptr == provided_grad_req_names - && nullptr != provided_grad_req_types) { // string, grad_req='write' - CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U) - << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; " - "only \'null\', \'write\', and \'add\' " - "are supported"; - grad_req_type = "string"; - } else if (provided_grad_req_list_len > 0 - && nullptr == provided_grad_req_names - && nullptr != provided_grad_req_types) { // list, grad_req=['null', 'write'] - grad_req_type = "list"; - CHECK_EQ(provided_grad_req_list_len, in_arg_names.size()) - << "The length of grad_req list does not match the number of input arguments in " - "simple_bind, expected " << in_arg_names.size() << ", provided " << - provided_grad_req_list_len; - } else if (provided_grad_req_list_len > 0 - && nullptr != provided_grad_req_names - && nullptr != provided_grad_req_types) { // dict, grad_req=['lhs': 'null', 'rhs': - // 'write'] - grad_req_type = "dict"; - provided_grad_req_map.reserve(provided_grad_req_list_len); - for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) { - CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) - << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " - "only \'null\', \'write\', and \'add\' " - "are supported"; - provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i]; - } - } else { // grad_req is None - grad_req_type = "none"; - } - - // initialize arg_grad_ctx_vec and grad_req_type_vec - std::vector arg_grad_ctx_vec(in_arg_names.size(), ctx); - std::vector grad_req_type_vec(in_arg_names.size(), kNullOp); - if ("none" != grad_req_type) { - for (size_t i = 0; i < in_arg_names.size(); ++i) { - OpReqType cur_req = kNullOp; - if ("string" == grad_req_type) { - cur_req = req_map.at(provided_grad_req_types[0]); - } else if ("list" == grad_req_type) { - CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U) - << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; " - "only \'null\', \'write\', and " - "\'add\' are supported"; - cur_req = req_map.at(provided_grad_req_types[i]); - } else if ("dict" == grad_req_type) { - const auto it = provided_grad_req_map.find(in_arg_names[i]); - if (it != provided_grad_req_map.end()) { - cur_req = req_map.at(it->second); - } - } - if (kNullOp != cur_req) { - arg_grad_ctx_vec[i] = in_arg_ctx_vec[i]; - grad_req_type_vec[i] = static_cast(cur_req); - } - } - } - - // create shape map for in_args and aux_states - std::unordered_map arg_shape_map(num_provided_arg_shapes); - for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) { - auto p = arg_shape_map.emplace(provided_arg_shape_names[i], - TShape(provided_arg_shape_data+provided_arg_shape_idx[i], - provided_arg_shape_data+provided_arg_shape_idx[i+1])); - CHECK(p.second) << "Duplicate shapes are provided for argument " - << provided_arg_shape_names[i] << " in simple_bind"; - } - - // create para name set for sharing data array memory - std::unordered_set shared_arg_name_set(num_shared_arg_names); - for (mx_uint i = 0; i < num_shared_arg_names; ++i) { - shared_arg_name_set.insert(shared_arg_name_list[i]); - } - - // create shared_buffer_map - std::unordered_map shared_buffer_map; - bool use_shared_buffer = (*shared_buffer_len >= 0); - if (*shared_buffer_len > 0) { - // create shared_buffer_map - shared_buffer_map.reserve(*shared_buffer_len); - NDArray** shared_buffer_ptrs = - reinterpret_cast(shared_buffer_handle_list); - for (int i = 0; i < *shared_buffer_len; ++i) { - shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]); - } - } - - // create temporary place holders for the initialized NDArrays - // to be passed back to front end - std::vector in_arg_vec; - std::vector arg_grad_vec; - std::vector aux_state_vec; - - *out = Executor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, - &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map, - &grad_req_type_vec, shared_arg_name_set, &in_arg_vec, - &arg_grad_vec, &aux_state_vec, - use_shared_buffer ? &shared_buffer_map : nullptr, - reinterpret_cast(shared_exec_handle)); - - // copy ndarray ptrs to ret->handles so that front end - // can access them - ret->ret_handles.clear(); - ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size() - +shared_buffer_map.size()); - size_t nd_idx = 0; - for (const auto& nd : in_arg_vec) { - if (nd.is_none()) { - LOG(FATAL) << "Input argument NDArray cannot be un-allocated"; - } - ret->ret_handles.push_back(new NDArray(nd)); - } - if (in_arg_vec.size() > 0) { - *num_in_args = in_arg_vec.size(); - *in_args = &(ret->ret_handles[nd_idx]); - nd_idx = ret->ret_handles.size(); - } - - for (const auto& nd : arg_grad_vec) { - if (nd.is_none()) { - ret->ret_handles.push_back(nullptr); - } else { - ret->ret_handles.push_back(new NDArray(nd)); - } - } - if (arg_grad_vec.size() > 0) { - *arg_grads = &(ret->ret_handles[nd_idx]); - nd_idx = ret->ret_handles.size(); - } - - for (const auto& nd : aux_state_vec) { - if (nd.is_none()) { - LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated"; - } - ret->ret_handles.push_back(new NDArray(nd)); - } - if (aux_state_vec.size() > 0) { - *num_aux_states = aux_state_vec.size(); - *aux_states = &(ret->ret_handles[nd_idx]); - nd_idx = ret->ret_handles.size(); - } - - if (use_shared_buffer) { - ret->ret_vec_str.clear(); - ret->ret_vec_str.reserve(shared_buffer_map.size()); - ret->ret_vec_charp.clear(); - ret->ret_vec_charp.reserve(shared_buffer_map.size()); - for (const auto& kv : shared_buffer_map) { - if (kv.second.is_none()) { - LOG(FATAL) << "Shared data NDArray cannot be un-allocated"; - } - ret->ret_handles.push_back(new NDArray(kv.second)); - ret->ret_vec_str.emplace_back(kv.first); - ret->ret_vec_charp.push_back(ret->ret_vec_str.back().c_str()); - } - *shared_buffer_len = shared_buffer_map.size(); - *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]); - *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]); - } - - API_END(); -} - -#endif // MXNET_USE_TENSORRT - int MXExecutorReshape(int partial_shaping, int allow_up_sizing, int dev_type, diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 433e7aadce00..f9c286b596be 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -57,7 +57,7 @@ GraphExecutor::~GraphExecutor() { } inline NDArray GraphExecutor::InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype) { + const Context &ctx, const int dtype) { // NDArray with default storage if (stype == kDefaultStorage) { NDArray ret(shape, ctx, false, dtype); @@ -69,8 +69,8 @@ inline NDArray GraphExecutor::InitZeros(const NDArrayStorageType stype, const TS } inline void GraphExecutor::EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec) { + const Context &ctx, const int dtype, + std::vector *vec) { // NDArray with default storage if (stype == kDefaultStorage) { vec->emplace_back(shape, ctx, false, dtype); @@ -438,8 +438,8 @@ Graph GraphExecutor::AssignContext(Graph g, } void GraphExecutor::HandleInferShapeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes) { + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes) { int cnt = 10; std::ostringstream oss; for (size_t i = 0; i < num_forward_inputs; ++i) { @@ -461,8 +461,8 @@ void GraphExecutor::HandleInferShapeError(const size_t num_forward_inputs, } void GraphExecutor::HandleInferTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::DTypeVector& inferred_dtypes) { + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes) { int cnt = 10; std::ostringstream oss; for (size_t i = 0; i < num_forward_inputs; ++i) { @@ -484,8 +484,8 @@ void GraphExecutor::HandleInferTypeError(const size_t num_forward_inputs, } void GraphExecutor::HandleInferStorageTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const StorageTypeVector& inferred_stypes) { + const nnvm::IndexedGraph& idx, + const StorageTypeVector& inferred_stypes) { int cnt = 10; std::ostringstream oss; for (size_t i = 0; i < num_forward_inputs; ++i) { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 13813eb5b8d2..05429b2508a5 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -164,19 +164,19 @@ class GraphExecutor : public Executor { // Initialize in_args, arg_grads and aux_states with // shared_buffer and shared_exec virtual void InitArguments(const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes, - const nnvm::DTypeVector& inferred_dtypes, - const StorageTypeVector& inferred_stypes, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - const std::unordered_set& shared_arg_names, - const Executor* shared_exec, - std::unordered_map* shared_buffer, - std::vector* in_arg_vec, - std::vector* arg_grad_vec, - std::vector* aux_state_vec); + const nnvm::ShapeVector& inferred_shapes, + const nnvm::DTypeVector& inferred_dtypes, + const StorageTypeVector& inferred_stypes, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + const std::unordered_set& shared_arg_names, + const Executor* shared_exec, + std::unordered_map* shared_buffer, + std::vector* in_arg_vec, + std::vector* arg_grad_vec, + std::vector* aux_state_vec); // internal initialization of the graph for simple bind Graph InitGraph(nnvm::Symbol symbol, @@ -213,21 +213,26 @@ class GraphExecutor : public Executor { void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); - + // prints a helpful message after shape inference errors in executor. static void HandleInferShapeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes); + // prints a helpful message after type inference errors in executor. static void HandleInferTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::DTypeVector& inferred_dtypes); + // prints a helpful message after storage type checking errors in executor. static void HandleInferStorageTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const StorageTypeVector& inferred_stypes); + // helper to initialize an NDArray to all zeros. static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype); + // helper to add a NDArray of zeros to a std::vector. static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype, std::vector *vec); + // helper to reshape an NDArray of certain shape if it doesn't already exist. static NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index be9e3d7d765a..3df9f0474bad 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -30,7 +30,7 @@ namespace mxnet { namespace exec { /*! - * \brief GraphExecutor initializer for simple bind flow in + * \brief TrtGraphExecutor initializer for simple bind flow in * which only certain input shapes and dtypes are provided by users. * The initializer uses these shapes and dtypes to perform * shape and dtype inferences, and then create NDArrays @@ -40,8 +40,14 @@ namespace exec { * In front end, if the simple_bind flow is trigger by * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup * and shared executor will be taken into account in creating - * NDArrays for in_args, arg_grads, and aux_states for resuing + * NDArrays for in_args, arg_grads, and aux_states for reusing * already allocated memory. + * + * This version of an executor exports the computation graph to TensorRT make use of fused + * kernels and other runtime enhancements. TRT will compile the sub-graphs to executable fused + * operators without intervention from the user. Operators in the original graph that are not + * supported by TRT will continue to be executed normally by MXNet. + * */ void TrtGraphExecutor::Init(nnvm::Symbol symbol, const Context& default_ctx, @@ -256,30 +262,13 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } } else { // !shared_arg_names.count(arg_name) // model parameter, row_sparse ndarray sharing enabled - bool enable_row_sparse_sharing = true; - - if (use_tensorrt_ && !need_grad_) { -#if MXNET_USE_TENSORRT - auto it = shared_buffer->find(arg_name); - if (it != shared_buffer->end()) { - in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); - } else { - in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, - in_arg_ctxes[arg_top], inferred_dtype))); - } -#else - LOG(FATAL) << "Env. var. MXNET_USE_TENSORRT = 1 set, but MXNet wasn't " - << "built with TensorRT. Add USE_TENSORRT = 1 in config.mk"; -#endif + auto it = shared_buffer->find(arg_name); + if (it != shared_buffer->end()) { + in_arg_vec->push_back(std::move(it->second.Copy(in_arg_ctxes[arg_top]))); } else { - if (use_tensorrt_) { - LOG(WARNING) << "USE_TENSORRT=1 but grads required. Running without TensorRT"; - } - in_arg_vec->emplace_back(ReshapeOrCreate( - arg_name, inferred_shape, inferred_dtype, inferred_stype, - in_arg_ctxes[arg_top], shared_buffer, enable_row_sparse_sharing)); + in_arg_vec->push_back(std::move(InitZeros(inferred_stype, inferred_shape, + in_arg_ctxes[arg_top], inferred_dtype))); } - // gradient for model parameter, row_sparse ndarray sharing disabled if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h index b201f0623462..4fd1d517b495 100644 --- a/src/executor/trt_graph_executor.h +++ b/src/executor/trt_graph_executor.h @@ -51,8 +51,9 @@ class TrtGraphExecutor : public GraphExecutor { std::unordered_map* shared_buffer = nullptr, Executor* shared_exec = nullptr, const nnvm::NodeEntryMap& feed_dict - = nnvm::NodeEntryMap()); + = nnvm::NodeEntryMap()); + // Returns symbol representing the TRT optimized graph for comparison purposes. nnvm::Symbol GetOptimizedSymbol(); protected: diff --git a/tests/.gitignore b/tests/.gitignore index 6100add5e450..3e5eed695f0a 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -1,2 +1,2 @@ *_unittest -*.gz \ No newline at end of file +*.gz diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py index 451983b33e13..2c64b96fafca 100644 --- a/tests/python/tensorrt/test_cvnets.py +++ b/tests/python/tensorrt/test_cvnets.py @@ -25,8 +25,9 @@ from mxnet.gluon.data.vision import transforms -def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, ctx=mx.gpu(0), - batch_size=128): + +def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) h, w = 32, 32 net = gluoncv.model_zoo.get_model(model_name, pretrained=True) data = mx.sym.var('data') @@ -35,9 +36,9 @@ def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, ctx=mx. out = net(data) softmax = mx.sym.SoftmaxOutput(out, name='softmax') all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - executor = mx.contrib.tensorrt.optimize_graph(softmax, ctx=ctx, data=(batch_size, 3, h, w), - softmax_label=(batch_size,), grad_req='null', - shared_buffer=all_params, force_rebind=True) + executor = softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), + softmax_label=( batch_size,), grad_req='null', + shared_buffer=all_params, force_rebind=True) else: # Convert gluon model to Symbolic net.hybridize() @@ -45,13 +46,12 @@ def get_classif_model(model_name='cifar_resnet56_v1', use_tensorrt=True, ctx=mx. net.export(model_name) symbol, arg_params, aux_params = mx.model.load_checkpoint(model_name, 0) executor = symbol.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), - softmax_label=(batch_size,)) + softmax_label=(batch_size,)) executor.copy_params_from(arg_params, aux_params) return executor -def cifar10_infer(data_dir='./data', model_name='cifar_resnet56_v1', use_tensorrt=True, - ctx=mx.gpu(0), fp16_for_fp32_graph=False, batch_size=128, num_workers=1): +def cifar10_infer(model_name, use_tensorrt, num_workers, ctx=mx.gpu(0), batch_size=128): executor = get_classif_model(model_name, use_tensorrt, ctx, batch_size) num_ex = 10000 @@ -125,47 +125,51 @@ def run_experiment_for(model_name, batch_size, num_workers): def test_tensorrt_on_cifar_resnets(batch_size=32, tolerance=0.1, num_workers=1): - models = [ - 'cifar_resnet20_v1', - 'cifar_resnet56_v1', - 'cifar_resnet110_v1', - 'cifar_resnet20_v2', - 'cifar_resnet56_v2', - 'cifar_resnet110_v2', - 'cifar_wideresnet16_10', - 'cifar_wideresnet28_10', - 'cifar_wideresnet40_8', - 'cifar_resnext29_16x64d' - ] - - num_models = len(models) - - speedups = np.zeros(num_models, dtype=np.float32) - acc_diffs = np.zeros(num_models, dtype=np.float32) - - test_start = time() - - for idx, model in enumerate(models): - speedup, acc_diff = run_experiment_for(model, batch_size, num_workers) - speedups[idx] = speedup - acc_diffs[idx] = acc_diff - assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % ( - tolerance, model) - - print("Perf and correctness checks run on the following models:") - print(models) - mean_speedup = np.mean(speedups) - std_speedup = np.std(speedups) - print("\nSpeedups:") - print(speedups) - print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) - print("Mean speedup: %.2f" % mean_speedup) - print("St. dev. of speedups: %.2f" % std_speedup) - print("\nAcc. differences: %s" % str(acc_diffs)) - - test_duration = time() - test_start - - print("Test duration: %.2f seconds" % test_duration) + original_try_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + models = [ + 'cifar_resnet20_v1', + 'cifar_resnet56_v1', + 'cifar_resnet110_v1', + 'cifar_resnet20_v2', + 'cifar_resnet56_v2', + 'cifar_resnet110_v2', + 'cifar_wideresnet16_10', + 'cifar_wideresnet28_10', + 'cifar_wideresnet40_8', + 'cifar_resnext29_16x64d' + ] + + num_models = len(models) + + speedups = np.zeros(num_models, dtype=np.float32) + acc_diffs = np.zeros(num_models, dtype=np.float32) + + test_start = time() + + for idx, model in enumerate(models): + speedup, acc_diff = run_experiment_for(model, batch_size, num_workers) + speedups[idx] = speedup + acc_diffs[idx] = acc_diff + assert acc_diff < tolerance, "Accuracy difference between MXNet and TensorRT > %.2f%% for model %s" % ( + tolerance, model) + + print("Perf and correctness checks run on the following models:") + print(models) + mean_speedup = np.mean(speedups) + std_speedup = np.std(speedups) + print("\nSpeedups:") + print(speedups) + print("Speedup range: [%.2f, %.2f]" % (np.min(speedups), np.max(speedups))) + print("Mean speedup: %.2f" % mean_speedup) + print("St. dev. of speedups: %.2f" % std_speedup) + print("\nAcc. differences: %s" % str(acc_diffs)) + + test_duration = time() - test_start + + print("Test duration: %.2f seconds" % test_duration) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_try_value) if __name__ == '__main__': diff --git a/tests/python/tensorrt/test_cycle.py b/tests/python/tensorrt/test_cycle.py index a3319ea02f25..25f515a106a6 100644 --- a/tests/python/tensorrt/test_cycle.py +++ b/tests/python/tensorrt/test_cycle.py @@ -58,9 +58,10 @@ def test_simple_cycle(): 'B_bias': mx.nd.zeros([10]), } - executor = mx.contrib.tensorrt.optimize_graph(C, ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), + executor = C.simple_bind(ctx=mx.gpu(0), data=(1,10), softmax_label=(1,), shared_buffer=arg_params, grad_req='null', force_rebind=True) - assert has_no_cycle(executor.optimized_symbol), "The graph optimized by TRT contains a cycle" + optimized_graph = mx.contrib.tensorrt.get_optimized_symbol(executor) + assert has_no_cycle(optimized_graph), "The graph optimized by TRT contains a cycle" if __name__ == '__main__': diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index e497faeabc7a..8e2730bcc7d8 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -21,25 +21,21 @@ from common import * from lenet5_common import get_iters + def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size, use_tensorrt): """Run inference with either MXNet or TensorRT""" + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) shared_buffer = merge_dicts(arg_params, aux_params) if not use_tensorrt: shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) - executor = sym.simple_bind(ctx=mx.gpu(0), + + executor = sym.simple_bind(ctx=mx.gpu(0), data=(batch_size,) + mnist['test_data'].shape[1:], softmax_label=(batch_size,), shared_buffer=shared_buffer, grad_req='null', force_rebind=True) - else: - executor = mx.contrib.tensorrt.optimize_graph(sym, ctx=mx.gpu(0), - data=(batch_size,) + mnist['test_data'].shape[1:], - softmax_label=(batch_size,), - shared_buffer=shared_buffer, - grad_req='null', - force_rebind=True) # Get this value from all_test_labels # Also get classes from the dataset @@ -64,37 +60,42 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz return percentage + def test_tensorrt_inference(): """Run LeNet-5 inference comparison between MXNet and TensorRT.""" - check_tensorrt_installation() - mnist = mx.test_utils.get_mnist() - num_epochs = 10 - batch_size = 128 - model_name = 'lenet5' - model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") - model_file = '%s/%s-symbol.json' % (model_dir, model_name) - params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) - - _, _, _, all_test_labels = get_iters(mnist, batch_size) - - # Load serialized MXNet model (model-symbol.json + model-epoch.params) - sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) - - print("LeNet-5 test") - print("Running inference in MXNet") - mx_pct = run_inference(sym, arg_params, aux_params, mnist, - all_test_labels, batch_size=batch_size) - - print("Running inference in MXNet-TensorRT") - trt_pct = run_inference(sym, arg_params, aux_params, mnist, - all_test_labels, batch_size=batch_size) - - print("MXNet accuracy: %f" % mx_pct) - print("MXNet-TensorRT accuracy: %f" % trt_pct) - - assert abs(mx_pct - trt_pct) < 1e-2, \ - """Diff. between MXNet & TensorRT accuracy too high: - MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct) + original_try_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + check_tensorrt_installation() + mnist = mx.test_utils.get_mnist() + num_epochs = 10 + batch_size = 128 + model_name = 'lenet5' + model_dir = os.getenv("LENET_MODEL_DIR", "/tmp") + model_file = '%s/%s-symbol.json' % (model_dir, model_name) + params_file = '%s/%s-%04d.params' % (model_dir, model_name, num_epochs) + + _, _, _, all_test_labels = get_iters(mnist, batch_size) + + # Load serialized MXNet model (model-symbol.json + model-epoch.params) + sym, arg_params, aux_params = mx.model.load_checkpoint(model_name, num_epochs) + + print("LeNet-5 test") + print("Running inference in MXNet") + mx_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, + batch_size=batch_size, use_tensorrt=False) + + print("Running inference in MXNet-TensorRT") + trt_pct = run_inference(sym, arg_params, aux_params, mnist, all_test_labels, + batch_size=batch_size, use_tensorrt=True) + + print("MXNet accuracy: %f" % mx_pct) + print("MXNet-TensorRT accuracy: %f" % trt_pct) + + assert abs(mx_pct - trt_pct) < 1e-2, \ + """Diff. between MXNet & TensorRT accuracy too high: + MXNet = %f, TensorRT = %f""" % (mx_pct, trt_pct) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_try_value) if __name__ == '__main__': From 3ea9b89d08889494ae1792fc722fe2c26b26a75a Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 14:10:12 +0200 Subject: [PATCH 24/43] Remove unused declaration --- include/mxnet/executor.h | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 412e12b01e20..0ab04b86a0a1 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -167,28 +167,6 @@ class Executor { shared_data_arrays = nullptr, Executor* shared_exec = nullptr); -#if MXNET_USE_TENSORRT - - static Executor* TensorRTBind(nnvm::Symbol symbol, - const Context& default_ctx, - const std::map& group2ctx, - std::vector *in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - const std::unordered_set& param_names, - std::vector* in_args, - std::vector* arg_grads, - std::vector* aux_states, - std::unordered_map* - shared_data_arrays = nullptr, - Executor* shared_exec = nullptr); - -#endif // MXNET_USE_TENSORRT - /*! * \brief the prototype of user-defined monitor callback */ From d6d2cac66de7c2f46ba922f5d83b8f45a7bbfea1 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 14:10:32 +0200 Subject: [PATCH 25/43] Replace build guards with runtime errors --- include/mxnet/c_api.h | 5 ----- src/c_api/c_api_executor.cc | 12 ++++++++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index de9918a4cf96..58b1b1b4dafe 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1715,17 +1715,12 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping, ExecutorHandle shared_exec, ExecutorHandle *out); - -#if MXNET_USE_TENSORRT - /*! * \brief get optimized graph from graph executor */ MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out); -#endif // MXNET_USE_TENSORRT - /*! * \brief set a call back to notify the completion of operation */ diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 64390c46b067..501769ffdd85 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -616,19 +616,27 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete out); } -#if MXNET_USE_TENSORRT + int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out) { + auto s = new nnvm::Symbol(); API_BEGIN(); + +#if MXNET_USE_TENSORRT auto exec = static_cast(handle); *s = exec->GetOptimizedSymbol(); *out = s; +#else + LOG(FATAL) << "GetOptimizedSymbol may only be used when MXNet is compiled with " + "MXNET_USE_TENSORRT enabled. Please re-compile MXNet with TensorRT support."; +#endif // MXNET_USE_TENSORRT + API_END_HANDLE_ERROR(delete s); } -#endif // MXNET_USE_TENSORRT + int MXExecutorSetMonitorCallback(ExecutorHandle handle, ExecutorMonitorCallback callback, From 2d04aeea7bc25eff0f47ca91757827c872baff32 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 14:13:40 +0200 Subject: [PATCH 26/43] Change default value of TensorRT to off This is change applies to both TensorRT and non-TensorRT builds. --- python/mxnet/contrib/tensorrt.py | 2 +- src/c_api/c_api_executor.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 7720f717d2c5..57afdb6539a0 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -44,7 +44,7 @@ def get_use_tensorrt(): :return: Boolean, true if TensorRT optimization should be applied, False for legacy behaviour. """ - return bool(int(os.environ.get("MXNET_USE_TENSORRT", 1)) == 1) + return bool(int(os.environ.get("MXNET_USE_TENSORRT", 0)) == 1) def get_optimized_symbol(executor): diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 501769ffdd85..d6a15a819a6b 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -447,7 +447,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, // If we've built with TensorRT support we by default return an TRTExecutor. // Users can override this behaviour via env var, which is useful for example for A/B // performance testing. - if (dmlc::GetEnv("MXNET_USE_TENSORRT", true)) { + if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { *out = Executor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, &arg_stype_map, &grad_req_type_vec, shared_arg_name_set, From 449a195679933d3841f7303662f9c3353ff9eda2 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 14:37:36 +0200 Subject: [PATCH 27/43] Warn user when TRT not active at runtime --- src/c_api/c_api_executor.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index d6a15a819a6b..0ed24d2cb16b 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -455,6 +455,14 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); } else { + // Checks to see if this env var has been set to true or false by the user. + // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn + // them and describe further steps. + const int unset_indicator = std::numeric_limits::quiet_NaN(); + if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) { + LOG(INFO) << "TensorRT not enabled by default. Please set MXNET_USE_TENSORRT environment " + "variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) to enable."; + } #endif // MXNET_USE_TENSORRT *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, From ed3673927c9f5c5a533841b853d065d7c8e0d78a Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 15:14:41 +0200 Subject: [PATCH 28/43] Move TensorRTBind declaration, add descriptive errors --- src/c_api/c_api_executor.cc | 18 +++++++------ src/executor/trt_graph_executor.cc | 43 +++++++++++++++++------------- src/executor/trt_graph_executor.h | 20 ++++++++++++++ 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 0ed24d2cb16b..05ab4ee79926 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -448,20 +448,22 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, // Users can override this behaviour via env var, which is useful for example for A/B // performance testing. if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) { - *out = Executor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, &arg_grad_ctx_vec, - &aux_state_ctx_vec, &arg_shape_map, &arg_dtype_map, - &arg_stype_map, &grad_req_type_vec, shared_arg_name_set, - &in_arg_vec, &arg_grad_vec, &aux_state_vec, - use_shared_buffer ? &shared_buffer_map : nullptr, - reinterpret_cast(shared_exec_handle)); + *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec, + &arg_grad_ctx_vec, &aux_state_ctx_vec, + &arg_shape_map, &arg_dtype_map, &arg_stype_map, + &grad_req_type_vec, shared_arg_name_set, + &in_arg_vec, &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, + reinterpret_cast(shared_exec_handle)); } else { // Checks to see if this env var has been set to true or false by the user. // If the user is using a TensorRT build, but has not enabled TRT at inference time, warn // them and describe further steps. const int unset_indicator = std::numeric_limits::quiet_NaN(); if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) { - LOG(INFO) << "TensorRT not enabled by default. Please set MXNET_USE_TENSORRT environment " - "variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) to enable."; + LOG(INFO) << "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT " + "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) " + "to enable."; } #endif // MXNET_USE_TENSORRT *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index 3df9f0474bad..cbf846c1653f 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -144,6 +144,13 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, // This function can be called by regular bind // operation flow as well. FinishInitGraph(symbol, g, shared_exec, feed_dict); + + if (need_grad_) { + LOG(FATAL) << "You may be attempting to use TensorRT for training. TensorRT is an inference " + "only library. To re-enable legacy MXNet graph execution, which will support " + "training, set the MXNET_USE_TENSORRT environment variable to 0, or call " + "mx.contrib.tensorrt.set_use_tensorrt(False)"; + } } /*! * \brief Initialize in_args, arg_grads, and aux_states @@ -404,24 +411,22 @@ nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() { return ret; } -} // namespace exec - -Executor *Executor::TensorRTBind(nnvm::Symbol symbol, - const Context& default_ctx, - const std::map& group2ctx, - std::vector *in_arg_ctxes, - std::vector* arg_grad_ctxes, - std::vector* aux_state_ctxes, - std::unordered_map* arg_shape_map, - std::unordered_map* arg_dtype_map, - std::unordered_map* arg_stype_map, - std::vector* grad_req_types, - const std::unordered_set& param_names, - std::vector* in_args, - std::vector* arg_grads, - std::vector* aux_states, - std::unordered_map* shared_buffer, - Executor* shared_exec) { +Executor *TrtGraphExecutor::TensorRTBind(nnvm::Symbol symbol, + const Context &default_ctx, + const std::map &group2ctx, + std::vector *in_arg_ctxes, + std::vector *arg_grad_ctxes, + std::vector *aux_state_ctxes, + std::unordered_map *arg_shape_map, + std::unordered_map *arg_dtype_map, + std::unordered_map *arg_stype_map, + std::vector *grad_req_types, + const std::unordered_set ¶m_names, + std::vector *in_args, + std::vector *arg_grads, + std::vector *aux_states, + std::unordered_map *shared_buffer, + Executor *shared_exec) { auto exec = new exec::TrtGraphExecutor(); exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, @@ -432,6 +437,8 @@ Executor *Executor::TensorRTBind(nnvm::Symbol symbol, return exec; } +} // namespace exec + } // namespace mxnet #endif // MXNET_USE_TENSORRT diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h index 4fd1d517b495..cf2bd85d5384 100644 --- a/src/executor/trt_graph_executor.h +++ b/src/executor/trt_graph_executor.h @@ -34,6 +34,25 @@ namespace exec { class TrtGraphExecutor : public GraphExecutor { public: + + static Executor* TensorRTBind(nnvm::Symbol symbol, + const Context& default_ctx, + const std::map& group2ctx, + std::vector *in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* aux_state_ctxes, + std::unordered_map* arg_shape_map, + std::unordered_map* arg_dtype_map, + std::unordered_map* arg_stype_map, + std::vector* grad_req_types, + const std::unordered_set& param_names, + std::vector* in_args, + std::vector* arg_grads, + std::vector* aux_states, + std::unordered_map* + shared_data_arrays = nullptr, + Executor* shared_exec = nullptr); + virtual void Init(nnvm::Symbol symbol, const Context& default_ctx, const std::map& ctx_map, @@ -85,6 +104,7 @@ class TrtGraphExecutor : public GraphExecutor { }; } // namespace exec + } // namespace mxnet #endif // MXNET_USE_TENSORRT From 882d8e5fc95e5303f8158a7bbea3a62f9aeef75f Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 17:09:50 +0200 Subject: [PATCH 29/43] Test TensorRT graph execution, fix bugs --- src/executor/trt_graph_executor.cc | 27 ++++---- .../python/tensorrt/test_training_warning.py | 65 +++++++++++++++++++ 2 files changed, 79 insertions(+), 13 deletions(-) create mode 100644 tests/python/tensorrt/test_training_warning.py diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index cbf846c1653f..1b0a25c857fa 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -69,6 +69,20 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, symbol = symbol.Copy(); nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, *in_arg_ctxes, *arg_grad_ctxes, *aux_state_ctxes, *grad_req_types); + + if (need_grad_) { + LOG(FATAL) << "You may be attempting to use TensorRT for training. TensorRT is an inference " + "only library. To re-enable legacy MXNet graph execution, which will support " + "training, set the MXNET_USE_TENSORRT environment variable to 0, or call " + "mx.contrib.tensorrt.set_use_tensorrt(False)"; + } + + if (shared_buffer == nullptr || shared_buffer->empty()) { + LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty. " + << "Please provide weights and other parameters, such as " + << "BatchNorm moments, via the shared_buffer, during simple bind call."; + } + // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation // should do this differently. @@ -113,11 +127,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, g.GetAttr("storage_type")); } - if (shared_buffer->empty()) { - LOG(FATAL) << "MXNET_USE_TENSORRT = 1 but shared_buffer is empty." - << "Please provide weights and other parameters, such as " - << "BatchNorm moments, via the shared_buffer, during simple bind call."; - } auto trt_groups = GetTrtCompatibleSubsets(g, shared_buffer); for (auto trt_group : trt_groups) { if (trt_group.size() > 1) { @@ -144,13 +153,6 @@ void TrtGraphExecutor::Init(nnvm::Symbol symbol, // This function can be called by regular bind // operation flow as well. FinishInitGraph(symbol, g, shared_exec, feed_dict); - - if (need_grad_) { - LOG(FATAL) << "You may be attempting to use TensorRT for training. TensorRT is an inference " - "only library. To re-enable legacy MXNet graph execution, which will support " - "training, set the MXNET_USE_TENSORRT environment variable to 0, or call " - "mx.contrib.tensorrt.set_use_tensorrt(False)"; - } } /*! * \brief Initialize in_args, arg_grads, and aux_states @@ -172,7 +174,6 @@ void TrtGraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, std::vector* in_arg_vec, std::vector* arg_grad_vec, std::vector* aux_state_vec) { - bool use_tensorrt_ = true; // initialize in_args, arg_grads, and aux_states and populate grad_store_ data_entry_.resize(idx.num_node_entries()); size_t arg_top = 0, aux_top = 0; diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py new file mode 100644 index 000000000000..3008a4234b58 --- /dev/null +++ b/tests/python/tensorrt/test_training_warning.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import gluoncv +import mxnet as mx + +from tests.python.unittest.common import assertRaises + + +def test_training_without_trt(): + run_resnet(is_train=True, use_tensorrt=False) + + +def test_inference_without_trt(): + run_resnet(is_train=False, use_tensorrt=False) + + +def test_training_with_trt(): + assertRaises(RuntimeError, run_resnet, is_train=True, use_tensorrt=True) + + +def test_inference_with_trt(): + run_resnet(is_train=False, use_tensorrt=True) + + +def run_resnet(is_train, use_tensorrt): + original_trt_value = mx.contrib.tensorrt.get_use_tensorrt() + try: + mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) + ctx = mx.gpu(0) + batch_size = 1 + h = 32 + w = 32 + model_name = 'cifar_resnet20_v1' + resnet = gluoncv.model_zoo.get_model(model_name, pretrained=True) + data = mx.sym.var('data') + out = resnet(data) + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) + if is_train: + grad_req = 'write' + else: + grad_req = 'null' + softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), + shared_buffer=all_params, force_rebind=True, grad_req=grad_req) + finally: + mx.contrib.tensorrt.set_use_tensorrt(original_trt_value) + + +if __name__ == '__main__': + import nose + nose.runmodule() From 95a7955e4ad8a5cc30dcf022a6acfa8867c6150d Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 17:25:38 +0200 Subject: [PATCH 30/43] Fix lint and whitespace issues --- src/c_api/c_api_executor.cc | 3 --- src/executor/trt_graph_executor.h | 1 - 2 files changed, 4 deletions(-) diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index 05ab4ee79926..b99350525bfa 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -626,11 +626,8 @@ int MXExecutorReshape(int partial_shaping, API_END_HANDLE_ERROR(delete out); } - - int MXExecutorGetOptimizedSymbol(ExecutorHandle handle, SymbolHandle *out) { - auto s = new nnvm::Symbol(); API_BEGIN(); diff --git a/src/executor/trt_graph_executor.h b/src/executor/trt_graph_executor.h index cf2bd85d5384..96ac4426270a 100644 --- a/src/executor/trt_graph_executor.h +++ b/src/executor/trt_graph_executor.h @@ -34,7 +34,6 @@ namespace exec { class TrtGraphExecutor : public GraphExecutor { public: - static Executor* TensorRTBind(nnvm::Symbol symbol, const Context& default_ctx, const std::map& group2ctx, From 0307467e6410eb70d27a85555641d5245182072b Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 17:31:01 +0200 Subject: [PATCH 31/43] Fix typo --- src/executor/trt_graph_executor.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index 1b0a25c857fa..a73013b49fca 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -395,7 +395,7 @@ Graph TrtGraphExecutor::ReinitGraph(Graph&& g, const Context &default_ctx, /*! - * \brief Return the "optimized" symbol contained in _graph. + * \brief Return the "optimized" symbol contained in the graph. * For optimization pass such as TensorRT pass */ nnvm::Symbol TrtGraphExecutor::GetOptimizedSymbol() { From 8504319d11a8823f134c721dbecabf2d7359e3dd Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 17:41:19 +0200 Subject: [PATCH 32/43] Removed default value for set_use_tensorrt --- python/mxnet/contrib/tensorrt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 57afdb6539a0..37d1f6843a76 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -27,7 +27,7 @@ from ..base import check_call -def set_use_tensorrt(status=False): +def set_use_tensorrt(status): """ Set an environment variable which will enable or disable the use of TensorRT in the backend. Note: this is useful for A/B testing purposes. From 55dd4226f2f4d2acf10a73b3783d936ada2a518c Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 17:41:38 +0200 Subject: [PATCH 33/43] Improved documentation and fixed spacing issues --- python/mxnet/contrib/tensorrt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 37d1f6843a76..11bdecc44868 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -48,12 +48,13 @@ def get_use_tensorrt(): def get_optimized_symbol(executor): - """Get optimized symbol. + """ + Take an executor's underlying symbol graph and return its generated optimized version. Parameters ---------- executor : - An executor for which you want to see an optimized symbol. Getting an optimized symbol + An executor for which you want to see an optimized symbol. Getting an optimized symbol is useful to compare and verify the work TensorRT has done against a legacy behaviour. Returns @@ -67,6 +68,6 @@ def get_optimized_symbol(executor): result = Symbol(handle=handle) return result except MXNetError: - logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure' + logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure ' 'build was compiled with MXNET_USE_TENSORRT enabled.') raise From ec9d3ea035f217e93c7093f10e23bd02c61be42f Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 22:11:04 +0200 Subject: [PATCH 34/43] Move static exec funcs to util files --- src/common/exec_utils.h | 252 ++++++++++++++++++++++++++++++ src/common/utils.h | 27 ++++ src/executor/graph_executor.cc | 273 --------------------------------- src/executor/graph_executor.h | 38 ----- 4 files changed, 279 insertions(+), 311 deletions(-) diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 816599b955c1..1757c9499d7b 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -24,6 +24,7 @@ #ifndef MXNET_COMMON_EXEC_UTILS_H_ #define MXNET_COMMON_EXEC_UTILS_H_ +#include #include #include #include @@ -366,6 +367,257 @@ inline void LogInferStorage(const nnvm::Graph& g) { } } +// prints a helpful message after shape inference errors in executor. +static void HandleInferShapeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::ShapeVector& inferred_shapes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const TShape& inferred_shape = inferred_shapes[eid]; + if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_shape << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " + "(0s means unknown dimensions). Please consider providing them as inputs:\n" + << oss.str(); +} + +// prints a helpful message after type inference errors in executor. +static void HandleInferTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_dtype = inferred_dtypes[eid]; + if (inferred_dtype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_dtype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " + "(-1 means unknown dtype). Please consider providing them as inputs:\n" + << oss.str(); +} + +// prints a helpful message after storage type checking errors in executor. +static void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const StorageTypeVector& inferred_stypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_stype = inferred_stypes[eid]; + if (inferred_stype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << common::stype_string(inferred_stype) << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments " + "(-1 means unknown stype). Please consider providing them as inputs:\n" + << oss.str(); +} + +/*! + * \brief If the requested ndarray's shape size is less than + * the corresponding shared_data_array's shape size and the + * storage type is shareable, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. + * Shareable storages include both default storage and row_sparse storage + * if enable_row_sparse_sharing is `True`, otherwise default storage only. + */ +static NDArray ReshapeOrCreate(const std::string& name, + const TShape& dest_arg_shape, + const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, + const Context& ctx, + std::unordered_map* shared_buffer, + bool enable_row_sparse_sharing) { + bool stype_shareable = dest_arg_stype == kDefaultStorage; + if (enable_row_sparse_sharing) { + stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; + } + auto it = shared_buffer->find(name); + if (it != shared_buffer->end()) { + // check if size is large enough for sharing + bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size(); + if (size_shareable && stype_shareable) { // memory can be reused + CHECK_EQ(it->second.dtype(), dest_arg_dtype) + << "Requested arg array's dtype does not match that of the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), dest_arg_stype) + << "Requested arg array's stype does not match that of the reusable ndarray"; + return it->second.Reshape(dest_arg_shape); + } else if (stype_shareable) { + LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape + << ", which is larger than already allocated shape " << it->second.shape() + << ". Need to re-allocate. Consider putting default bucket key to be " + << "the bucket taking the largest input for better memory sharing."; + // size is not large enough, creating a larger one for sharing + // the NDArrays in shared_buffer are guaranteed to be of shareable storages + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + return it->second; + } else { + // not shareable storage + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } + } else { + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + if (stype_shareable) { + shared_buffer->emplace(name, ret); + } + return ret; + } // if (it != shared_buffer->end()) +} + +/*! + * \brief Assign context to the graph. + * This is triggered by both simple_bind and bind flows. + */ +static Graph AssignContext(Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + size_t num_forward_inputs, + size_t num_forward_outputs) { + const auto& idx = g.indexed_graph(); + const auto& mutable_nodes = idx.mutable_input_nodes(); + // default use default context. + if (ctx_map.size() == 0) { + g.attrs["context"] = std::make_shared( + ContextVector(idx.num_nodes(), default_ctx)); + for (const auto& x : in_arg_ctxes) { + CHECK(x == default_ctx) + << "Input array is in " << x << " while binding with ctx=" << default_ctx + << ". All arguments must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; + } + for (const auto& x : arg_grad_ctxes) { + CHECK(x == default_ctx) + << "Gradient array is in " << x << " while binding with ctx=" + << default_ctx << ". All gradients must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; + } + return g; + } + + // otherwise, use context assignment. + std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id + nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id + nnvm::DeviceAssignMap device_map; // map arg name to device id + + // loop through the user input ctx_map and + // populate maps and lists + for (auto &kv : ctx_map) { + if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one + ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx + ctx_list.push_back(kv.second); // save ctx to the list + } + // assign device id to to the arg name with the corresponding ctx + device_map[kv.first] = ctx2id.at(kv.second); + } + + // loop through all the rest of input nodes not specified + // in the ctx_map and populate maps and lists + size_t arg_top = 0, aux_top = 0; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + Context ctx; + if (mutable_nodes.count(nid)) { // aux node is mutable + CHECK_LT(aux_top, aux_state_ctxes.size()); + ctx = aux_state_ctxes[aux_top]; + ++aux_top; + } else { // regular input node is immutable + CHECK_LT(arg_top, in_arg_ctxes.size()); + ctx = in_arg_ctxes[arg_top]; + ++arg_top; + } + if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id + ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id + ctx_list.push_back(ctx); // save the current ctx in the list + } + device[nid] = ctx2id.at(ctx); // assign device id to the current node + } + + // loop through backward input nodes and populate maps and lists + // the backward input nodes is the gradient of the loss wrt the output + size_t arg_grad_offset = 0; + // keep an offset into the arg_grad_ctxes vector, + // since g.outputs exclude arg_grad whose req == null + CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) + << "insufficient number of grad_reqs"; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + const uint32_t nid = idx.outputs()[i].node_id; + Context ctx = arg_grad_ctxes[arg_grad_offset]; + if (ctx2id.count(ctx) == 0) { + ctx2id[ctx] = static_cast(ctx_list.size()); + ctx_list.push_back(ctx); + } + int devid = ctx2id.at(ctx); + if (device[nid] != -1) { + CHECK_EQ(device[nid], devid) << "device of same output not equal to each other"; + } else { + device[nid] = devid; + } + } + + g.attrs["device"] = std::make_shared(std::move(device)); + g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); + const auto& assigned_device = g.GetAttr("device"); + + ContextVector vcontext; + for (size_t i = 0; i < assigned_device.size(); ++i) { + if (assigned_device[i] == -1) { + vcontext.push_back(default_ctx); + } else { + vcontext.push_back(ctx_list[assigned_device[i]]); + } + } + + // after device planning, we should check again + // if the assigned device of gradient node + // corresponds to storage of grads + auto &new_idx = g.indexed_graph(); + arg_grad_offset = 0; + for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { + while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + const uint32_t nid = new_idx.outputs()[i].node_id; + Context ctx = arg_grad_ctxes[arg_grad_offset]; + CHECK(ctx == vcontext[nid]) + << "Trying to save gradient to " << ctx + << " while its source node \"" << new_idx[nid].source->attrs.name + << "\" computes it on " << vcontext[nid] + << ". Check your ctx in NDArray allocation."; + } + + g.attrs["context"] = std::make_shared(std::move(vcontext)); + return g; +} } // namespace common } // namespace mxnet diff --git a/src/common/utils.h b/src/common/utils.h index 96949a047fba..4a1586044b12 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -675,6 +675,33 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) { return k; } +// helper to initialize an NDArray to all zeros. +static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +// helper to add a NDArray of zeros to a std::vector. +static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index f9c286b596be..f2782356849d 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -56,30 +56,6 @@ GraphExecutor::~GraphExecutor() { } } -inline NDArray GraphExecutor::InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype) { - // NDArray with default storage - if (stype == kDefaultStorage) { - NDArray ret(shape, ctx, false, dtype); - ret = 0; - return ret; - } - // NDArray with non-default storage. Storage allocation is always delayed. - return NDArray(stype, shape, ctx, true, dtype); -} - -inline void GraphExecutor::EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec) { - // NDArray with default storage - if (stype == kDefaultStorage) { - vec->emplace_back(shape, ctx, false, dtype); - vec->back() = 0; - } else { - // NDArray with non-default storage. Storage allocation is always delayed. - vec->emplace_back(stype, shape, ctx, true, dtype); - } -} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -308,204 +284,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, return g; } -/*! - * \brief Assign context to the graph. - * This is triggered by both simple_bind and bind flows. - */ -Graph GraphExecutor::AssignContext(Graph g, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - size_t num_forward_inputs, - size_t num_forward_outputs) { - const auto& idx = g.indexed_graph(); - const auto& mutable_nodes = idx.mutable_input_nodes(); - // default use default context. - if (ctx_map.size() == 0) { - g.attrs["context"] = std::make_shared( - ContextVector(idx.num_nodes(), default_ctx)); - for (const auto& x : in_arg_ctxes) { - CHECK(x == default_ctx) - << "Input array is in " << x << " while binding with ctx=" << default_ctx - << ". All arguments must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; - } - for (const auto& x : arg_grad_ctxes) { - CHECK(x == default_ctx) - << "Gradient array is in " << x << " while binding with ctx=" - << default_ctx << ". All gradients must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; - } - return g; - } - - // otherwise, use context assignment. - std::map ctx2id; // map ctx to device id - std::vector ctx_list; // index is device id - nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id - nnvm::DeviceAssignMap device_map; // map arg name to device id - - // loop through the user input ctx_map and - // populate maps and lists - for (auto &kv : ctx_map) { - if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one - ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx - ctx_list.push_back(kv.second); // save ctx to the list - } - // assign device id to to the arg name with the corresponding ctx - device_map[kv.first] = ctx2id.at(kv.second); - } - - // loop through all the rest of input nodes not specified - // in the ctx_map and populate maps and lists - size_t arg_top = 0, aux_top = 0; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - Context ctx; - if (mutable_nodes.count(nid)) { // aux node is mutable - CHECK_LT(aux_top, aux_state_ctxes.size()); - ctx = aux_state_ctxes[aux_top]; - ++aux_top; - } else { // regular input node is immutable - CHECK_LT(arg_top, in_arg_ctxes.size()); - ctx = in_arg_ctxes[arg_top]; - ++arg_top; - } - if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id - ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id - ctx_list.push_back(ctx); // save the current ctx in the list - } - device[nid] = ctx2id.at(ctx); // assign device id to the current node - } - - // loop through backward input nodes and populate maps and lists - // the backward input nodes is the gradient of the loss wrt the output - size_t arg_grad_offset = 0; - // keep an offset into the arg_grad_ctxes vector, - // since g.outputs exclude arg_grad whose req == null - CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) - << "insufficient number of grad_reqs"; - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; - const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; - if (ctx2id.count(ctx) == 0) { - ctx2id[ctx] = static_cast(ctx_list.size()); - ctx_list.push_back(ctx); - } - int devid = ctx2id.at(ctx); - if (device[nid] != -1) { - CHECK_EQ(device[nid], devid) << "device of same output not equal to each other"; - } else { - device[nid] = devid; - } - } - - g.attrs["device"] = std::make_shared(std::move(device)); - g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); - const auto& assigned_device = g.GetAttr("device"); - - ContextVector vcontext; - for (size_t i = 0; i < assigned_device.size(); ++i) { - if (assigned_device[i] == -1) { - vcontext.push_back(default_ctx); - } else { - vcontext.push_back(ctx_list[assigned_device[i]]); - } - } - - // after device planning, we should check again - // if the assigned device of gradient node - // corresponds to storage of grads - auto &new_idx = g.indexed_graph(); - arg_grad_offset = 0; - for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; - const uint32_t nid = new_idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; - CHECK(ctx == vcontext[nid]) - << "Trying to save gradient to " << ctx - << " while its source node \"" << new_idx[nid].source->attrs.name - << "\" computes it on " << vcontext[nid] - << ". Check your ctx in NDArray allocation."; - } - - g.attrs["context"] = std::make_shared(std::move(vcontext)); - return g; -} - -void GraphExecutor::HandleInferShapeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const TShape& inferred_shape = inferred_shapes[eid]; - if (inferred_shape.ndim() == 0 || inferred_shape.Size() == 0U) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << inferred_shape << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments " - "(0s means unknown dimensions). Please consider providing them as inputs:\n" - << oss.str(); -} - -void GraphExecutor::HandleInferTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::DTypeVector& inferred_dtypes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const int inferred_dtype = inferred_dtypes[eid]; - if (inferred_dtype == -1) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << inferred_dtype << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments " - "(-1 means unknown dtype). Please consider providing them as inputs:\n" - << oss.str(); -} - -void GraphExecutor::HandleInferStorageTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const StorageTypeVector& inferred_stypes) { - int cnt = 10; - std::ostringstream oss; - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = idx.input_nodes().at(i); - const uint32_t eid = idx.entry_id(nid, 0); - const int inferred_stype = inferred_stypes[eid]; - if (inferred_stype == -1) { - const std::string& arg_name = idx[nid].source->attrs.name; - oss << arg_name << ": " << common::stype_string(inferred_stype) << ", "; - if (--cnt == 0) { - oss << "..."; - break; - } - } - } - LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments " - "(-1 means unknown stype). Please consider providing them as inputs:\n" - << oss.str(); -} - /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -680,57 +458,6 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, } } -/*! - * \brief If the requested ndarray's shape size is less than - * the corresponding shared_data_array's shape size and the - * storage type is shareable, reuse the memory allocation - * in shared_buffer; otherwise, create a zero ndarray. - * Shareable storages include both default storage and row_sparse storage - * if enable_row_sparse_sharing is `True`, otherwise default storage only. - */ -NDArray GraphExecutor::ReshapeOrCreate(const std::string& name, - const TShape& dest_arg_shape, - const int dest_arg_dtype, - const NDArrayStorageType dest_arg_stype, - const Context& ctx, - std::unordered_map* shared_buffer, - bool enable_row_sparse_sharing) { - bool stype_shareable = dest_arg_stype == kDefaultStorage; - if (enable_row_sparse_sharing) { - stype_shareable = stype_shareable || dest_arg_stype == kRowSparseStorage; - } - auto it = shared_buffer->find(name); - if (it != shared_buffer->end()) { - // check if size is large enough for sharing - bool size_shareable = it->second.shape().Size() >= dest_arg_shape.Size(); - if (size_shareable && stype_shareable) { // memory can be reused - CHECK_EQ(it->second.dtype(), dest_arg_dtype) - << "Requested arg array's dtype does not match that of the reusable ndarray"; - CHECK_EQ(it->second.storage_type(), dest_arg_stype) - << "Requested arg array's stype does not match that of the reusable ndarray"; - return it->second.Reshape(dest_arg_shape); - } else if (stype_shareable) { - LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape - << ", which is larger than already allocated shape " << it->second.shape() - << ". Need to re-allocate. Consider putting default bucket key to be " - << "the bucket taking the largest input for better memory sharing."; - // size is not large enough, creating a larger one for sharing - // the NDArrays in shared_buffer are guaranteed to be of shareable storages - it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - return it->second; - } else { - // not shareable storage - return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - } - } else { - auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); - if (stype_shareable) { - shared_buffer->emplace(name, ret); - } - return ret; - } // if (it != shared_buffer->end()) -} - /*! * \brief Initialize in_args, arg_grads, and aux_states * and their data_entry_ of the executor using diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 05429b2508a5..7b936c300254 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -213,46 +213,8 @@ class GraphExecutor : public Executor { void BulkInferenceOpSegs(); // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); - // prints a helpful message after shape inference errors in executor. - static void HandleInferShapeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::ShapeVector& inferred_shapes); - // prints a helpful message after type inference errors in executor. - static void HandleInferTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const nnvm::DTypeVector& inferred_dtypes); - // prints a helpful message after storage type checking errors in executor. - static void HandleInferStorageTypeError(const size_t num_forward_inputs, - const nnvm::IndexedGraph& idx, - const StorageTypeVector& inferred_stypes); - // helper to initialize an NDArray to all zeros. - static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype); - // helper to add a NDArray of zeros to a std::vector. - static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec); - // helper to reshape an NDArray of certain shape if it doesn't already exist. - static NDArray ReshapeOrCreate(const std::string& name, - const TShape& dest_arg_shape, - const int dest_arg_dtype, - const NDArrayStorageType dest_arg_stype, - const Context& ctx, - std::unordered_map* shared_buffer, - bool enable_row_sparse_sharing); - // Assign context to the graph. - static Graph AssignContext(Graph g, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - size_t num_forward_inputs, - size_t num_forward_outputs); // indicate whether there is a backward graph for gradients. bool need_grad_; - // internal graph nnvm::Graph graph_; // operator node From 4b637381affe5c53268dcc75a9ec90e38cfe6b1f Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 22:13:39 +0200 Subject: [PATCH 35/43] Update comments to match util style --- src/common/utils.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/common/utils.h b/src/common/utils.h index 4a1586044b12..c5fed1348640 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -675,7 +675,9 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) { return k; } -// helper to initialize an NDArray to all zeros. +/*! + * \brief Return an NDArray of all zeros. + */ static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype) { // NDArray with default storage @@ -688,7 +690,9 @@ static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, return NDArray(stype, shape, ctx, true, dtype); } -// helper to add a NDArray of zeros to a std::vector. +/*! + * \brief Helper to add a NDArray of zeros to a std::vector. + */ static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, const Context &ctx, const int dtype, std::vector *vec) { From 694cbfb55649bbc925ed88e97d071420fa2a0da9 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 22:15:45 +0200 Subject: [PATCH 36/43] Apply const to loop element --- src/common/serialization.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/serialization.h b/src/common/serialization.h index 56b6069304d0..8a1bcc6e6ed2 100644 --- a/src/common/serialization.h +++ b/src/common/serialization.h @@ -106,7 +106,7 @@ template inline size_t SerializedSize(const nnvm::Tuple &obj) { if (is_container::value) { size_t sum_val = 4; - for (auto& el : obj) { + for (const auto& el : obj) { sum_val += SerializedSize(el); } return sum_val; From 2be7d257d01045c14cce02ad9223882c242c8b67 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 22:37:45 +0200 Subject: [PATCH 37/43] Fix a few namespace issues --- src/common/exec_utils.h | 56 ++++++++++++++++++++++++------ src/common/utils.h | 31 ----------------- src/executor/graph_executor.cc | 2 ++ src/executor/trt_graph_executor.cc | 2 ++ 4 files changed, 49 insertions(+), 42 deletions(-) diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 1757c9499d7b..c5d725bb4718 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -24,11 +24,14 @@ #ifndef MXNET_COMMON_EXEC_UTILS_H_ #define MXNET_COMMON_EXEC_UTILS_H_ +#include +#include #include #include #include #include #include "../common/utils.h" +#include "../executor/exec_pass.h" namespace mxnet { namespace common { @@ -367,6 +370,37 @@ inline void LogInferStorage(const nnvm::Graph& g) { } } +/*! + * \brief Return an NDArray of all zeros. + */ +static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +/*! + * \brief Helper to add a NDArray of zeros to a std::vector. + */ +static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} + // prints a helpful message after shape inference errors in executor. static void HandleInferShapeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, @@ -494,21 +528,21 @@ static NDArray ReshapeOrCreate(const std::string& name, * \brief Assign context to the graph. * This is triggered by both simple_bind and bind flows. */ -static Graph AssignContext(Graph g, - const Context& default_ctx, - const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& arg_grad_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types, - size_t num_forward_inputs, - size_t num_forward_outputs) { +static nnvm::Graph AssignContext(nnvm::Graph g, + const Context& default_ctx, + const std::map& ctx_map, + const std::vector& in_arg_ctxes, + const std::vector& arg_grad_ctxes, + const std::vector& aux_state_ctxes, + const std::vector& grad_req_types, + size_t num_forward_inputs, + size_t num_forward_outputs) { const auto& idx = g.indexed_graph(); const auto& mutable_nodes = idx.mutable_input_nodes(); // default use default context. if (ctx_map.size() == 0) { g.attrs["context"] = std::make_shared( - ContextVector(idx.num_nodes(), default_ctx)); + exec::ContextVector(idx.num_nodes(), default_ctx)); for (const auto& x : in_arg_ctxes) { CHECK(x == default_ctx) << "Input array is in " << x << " while binding with ctx=" << default_ctx @@ -590,7 +624,7 @@ static Graph AssignContext(Graph g, g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); const auto& assigned_device = g.GetAttr("device"); - ContextVector vcontext; + exec::ContextVector vcontext; for (size_t i = 0; i < assigned_device.size(); ++i) { if (assigned_device[i] == -1) { vcontext.push_back(default_ctx); diff --git a/src/common/utils.h b/src/common/utils.h index c5fed1348640..96949a047fba 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -675,37 +675,6 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) { return k; } -/*! - * \brief Return an NDArray of all zeros. - */ -static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype) { - // NDArray with default storage - if (stype == kDefaultStorage) { - NDArray ret(shape, ctx, false, dtype); - ret = 0; - return ret; - } - // NDArray with non-default storage. Storage allocation is always delayed. - return NDArray(stype, shape, ctx, true, dtype); -} - -/*! - * \brief Helper to add a NDArray of zeros to a std::vector. - */ -static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec) { - // NDArray with default storage - if (stype == kDefaultStorage) { - vec->emplace_back(shape, ctx, false, dtype); - vec->back() = 0; - } else { - // NDArray with non-default storage. Storage allocation is always delayed. - vec->emplace_back(stype, shape, ctx, true, dtype); - } -} - } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index f2782356849d..6810800c8b71 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -37,6 +37,8 @@ namespace mxnet { namespace exec { +using namespace mxnet::common; + GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); need_grad_ = false; diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index a73013b49fca..a68d2d38222a 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -29,6 +29,8 @@ namespace mxnet { namespace exec { +using namespace mxnet::common; + /*! * \brief TrtGraphExecutor initializer for simple bind flow in * which only certain input shapes and dtypes are provided by users. From 369a3f74337677ed65d58dfd6488e662b427799f Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Thu, 9 Aug 2018 23:26:42 +0200 Subject: [PATCH 38/43] Make static funcs inline to avoid compiler warning --- src/common/exec_utils.h | 41 ++++-------------------------- src/common/utils.h | 31 ++++++++++++++++++++++ src/executor/trt_graph_executor.cc | 3 +++ 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index c5d725bb4718..fbe544221a35 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -370,39 +370,8 @@ inline void LogInferStorage(const nnvm::Graph& g) { } } -/*! - * \brief Return an NDArray of all zeros. - */ -static NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype) { - // NDArray with default storage - if (stype == kDefaultStorage) { - NDArray ret(shape, ctx, false, dtype); - ret = 0; - return ret; - } - // NDArray with non-default storage. Storage allocation is always delayed. - return NDArray(stype, shape, ctx, true, dtype); -} - -/*! - * \brief Helper to add a NDArray of zeros to a std::vector. - */ -static void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, - const Context &ctx, const int dtype, - std::vector *vec) { - // NDArray with default storage - if (stype == kDefaultStorage) { - vec->emplace_back(shape, ctx, false, dtype); - vec->back() = 0; - } else { - // NDArray with non-default storage. Storage allocation is always delayed. - vec->emplace_back(stype, shape, ctx, true, dtype); - } -} - // prints a helpful message after shape inference errors in executor. -static void HandleInferShapeError(const size_t num_forward_inputs, +inline void HandleInferShapeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes) { int cnt = 10; @@ -426,7 +395,7 @@ static void HandleInferShapeError(const size_t num_forward_inputs, } // prints a helpful message after type inference errors in executor. -static void HandleInferTypeError(const size_t num_forward_inputs, +inline void HandleInferTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const nnvm::DTypeVector& inferred_dtypes) { int cnt = 10; @@ -450,7 +419,7 @@ static void HandleInferTypeError(const size_t num_forward_inputs, } // prints a helpful message after storage type checking errors in executor. -static void HandleInferStorageTypeError(const size_t num_forward_inputs, +inline void HandleInferStorageTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, const StorageTypeVector& inferred_stypes) { int cnt = 10; @@ -481,7 +450,7 @@ static void HandleInferStorageTypeError(const size_t num_forward_inputs, * Shareable storages include both default storage and row_sparse storage * if enable_row_sparse_sharing is `True`, otherwise default storage only. */ -static NDArray ReshapeOrCreate(const std::string& name, +inline NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, const NDArrayStorageType dest_arg_stype, @@ -528,7 +497,7 @@ static NDArray ReshapeOrCreate(const std::string& name, * \brief Assign context to the graph. * This is triggered by both simple_bind and bind flows. */ -static nnvm::Graph AssignContext(nnvm::Graph g, +inline nnvm::Graph AssignContext(nnvm::Graph g, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_arg_ctxes, diff --git a/src/common/utils.h b/src/common/utils.h index 96949a047fba..fcc3da82b051 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -675,6 +675,37 @@ MSHADOW_XINLINE int ilog2ui(unsigned int a) { return k; } +/*! + * \brief Return an NDArray of all zeros. + */ +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +/*! + * \brief Helper to add a NDArray of zeros to a std::vector. + */ +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/trt_graph_executor.cc b/src/executor/trt_graph_executor.cc index a68d2d38222a..65dbb29792e0 100644 --- a/src/executor/trt_graph_executor.cc +++ b/src/executor/trt_graph_executor.cc @@ -25,6 +25,9 @@ #include #include "./onnx_to_tensorrt.h" #include "../operator/contrib/tensorrt-inl.h" +#include "../common/utils.h" +#include "../common/exec_utils.h" + namespace mxnet { namespace exec { From 1c7698bb92e100580aec9f8e828a2f566de580ff Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 10 Aug 2018 02:23:08 +0200 Subject: [PATCH 39/43] Remove unused inference code from lenet5_train --- tests/python/tensorrt/lenet5_train.py | 40 ++------------------------- 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/tests/python/tensorrt/lenet5_train.py b/tests/python/tensorrt/lenet5_train.py index 74de66620e88..8edd9abf70e7 100644 --- a/tests/python/tensorrt/lenet5_train.py +++ b/tests/python/tensorrt/lenet5_train.py @@ -16,10 +16,10 @@ # under the License. import os -import numpy as np import mxnet as mx from lenet5_common import get_iters + def lenet5(): """LeNet-5 Symbol""" #pylint: disable=no-member @@ -44,6 +44,7 @@ def lenet5(): #pylint: enable=no-member return lenet + def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter): """train LeNet-5 model on MNIST data""" ctx = mx.gpu(0) @@ -65,44 +66,7 @@ def train_lenet5(num_epochs, batch_size, train_iter, val_iter, test_iter): return lenet_model -def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_size): - """Run inference with either MXNet or TensorRT""" - - shared_buffer = merge_dicts(arg_params, aux_params) - if not get_use_tensorrt(): - shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) - executor = sym.simple_bind(ctx=mx.gpu(0), - data=(batch_size,) + mnist['test_data'].shape[1:], - softmax_label=(batch_size,), - shared_buffer=shared_buffer, - grad_req='null', - force_rebind=True) - - # Get this value from all_test_labels - # Also get classes from the dataset - num_ex = 10000 - all_preds = np.zeros([num_ex, 10]) - test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) - - example_ct = 0 - - for idx, dbatch in enumerate(test_iter): - executor.arg_dict["data"][:] = dbatch.data[0] - executor.forward(is_train=False) - offset = idx*batch_size - extent = batch_size if num_ex - offset > batch_size else num_ex - offset - all_preds[offset:offset+extent, :] = executor.outputs[0].asnumpy()[:extent] - example_ct += extent - - all_preds = np.argmax(all_preds, axis=1) - matches = (all_preds[:example_ct] == all_test_labels[:example_ct]).sum() - - percentage = 100.0 * matches / example_ct - - return percentage - if __name__ == '__main__': - num_epochs = 10 batch_size = 128 model_name = 'lenet5' From 74b660321ff9ab5db7987d3d13fb462ada5674a4 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 10 Aug 2018 02:23:50 +0200 Subject: [PATCH 40/43] Add explicit trt contrib bind, update tests to use it --- python/mxnet/contrib/tensorrt.py | 58 +++++++++++++++++++ tests/python/tensorrt/test_cvnets.py | 7 ++- tests/python/tensorrt/test_tensorrt_lenet5.py | 25 ++++---- .../python/tensorrt/test_training_warning.py | 11 +++- 4 files changed, 85 insertions(+), 16 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 11bdecc44868..0124a8966737 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -71,3 +71,61 @@ def get_optimized_symbol(executor): logging.error('Error while trying to fetch TRT optimized symbol for graph. Please ensure ' 'build was compiled with MXNET_USE_TENSORRT enabled.') raise + + +def simple_bind(symbol, all_params, ctx, **kwargs): + """Bind current symbol to get an optimized trt executor. + + Parameters + ---------- + symbol : Symbol + The symbol you wish to bind, and optimize with TensorRT. + + all_params : Dict of str->ndarray + A dictionary of mappings from parameter names to parameter NDArrays. + + ctx : Context + The device context the generated executor to run on. + + grad_req: string + {'write', 'add', 'null'}, or list of str or dict of str to str, optional + To specify how we should update the gradient to the `args_grad`. + + - 'write' means every time gradient is written to specified `args_grad` NDArray. + - 'add' means every time gradient is added to the specified NDArray. + - 'null' means no action is taken, the gradient may not be calculated. This is the only + mode supported by TensorRT + + type_dict : Dict of str->numpy.dtype + Input type dictionary, name->dtype + + stype_dict : Dict of str->str + Input storage type dictionary, name->storage_type + + group2ctx : Dict of string to mx.Context + The dict mapping the `ctx_group` attribute to the context assignment. + + shared_arg_names : List of string + The argument names whose `NDArray` of shared_exec can be reused for initializing + the current executor. + + shared_exec : Executor + The executor whose arg_arrays, arg_arrays, grad_arrays, and aux_arrays can be + reused for initializing the current executor. + + shared_buffer : Dict of string to `NDArray` + The dict mapping argument names to the `NDArray` that can be reused for initializing + the current executor. This buffer will be checked for reuse if one argument name + of the current executor is not found in `shared_arg_names`. The `NDArray`s are + expected have default storage type. + + kwargs : Dict of str->shape + Input shape dictionary, name->shape + + Returns + ------- + executor : mxnet.Executor + An optimized TensorRT executor. + """ + kwargs['shared_buffer'] = all_params + return symbol.simple_bind(ctx, **kwargs) diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py index 2c64b96fafca..433840dbb151 100644 --- a/tests/python/tensorrt/test_cvnets.py +++ b/tests/python/tensorrt/test_cvnets.py @@ -36,9 +36,10 @@ def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): out = net(data) softmax = mx.sym.SoftmaxOutput(out, name='softmax') all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - executor = softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), - softmax_label=( batch_size,), grad_req='null', - shared_buffer=all_params, force_rebind=True) + executor = mx.contrib.tensorrt.simple_bind(softmax, all_params, ctx=ctx, + data=(batch_size,3, h, w), + softmax_label=(batch_size,), grad_req='null', + force_rebind=True) else: # Convert gluon model to Symbolic net.hybridize() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index 8e2730bcc7d8..a91e25f858c9 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -26,16 +26,21 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz """Run inference with either MXNet or TensorRT""" mx.contrib.tensorrt.set_use_tensorrt(use_tensorrt) - shared_buffer = merge_dicts(arg_params, aux_params) - if not use_tensorrt: - shared_buffer = dict([(k, v.as_in_context(mx.gpu(0))) for k, v in shared_buffer.items()]) - - executor = sym.simple_bind(ctx=mx.gpu(0), - data=(batch_size,) + mnist['test_data'].shape[1:], - softmax_label=(batch_size,), - shared_buffer=shared_buffer, - grad_req='null', - force_rebind=True) + data_size = (batch_size,) + mnist['test_data'].shape[1:] + if use_tensorrt: + all_params = merge_dicts(arg_params, aux_params) + executor = mx.contrib.tensorrt.simple_bind(sym, all_params, ctx=mx.gpu(0), + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) + else: + executor = sym.simple_bind(ctx=mx.gpu(0), + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) + executor.copy_params_from(arg_params, aux_params) # Get this value from all_test_labels # Also get classes from the dataset diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py index 3008a4234b58..a8dd1a465093 100644 --- a/tests/python/tensorrt/test_training_warning.py +++ b/tests/python/tensorrt/test_training_warning.py @@ -49,13 +49,18 @@ def run_resnet(is_train, use_tensorrt): data = mx.sym.var('data') out = resnet(data) softmax = mx.sym.SoftmaxOutput(out, name='softmax') - all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) if is_train: grad_req = 'write' else: grad_req = 'null' - softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), - shared_buffer=all_params, force_rebind=True, grad_req=grad_req) + if use_tensorrt: + all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) + mx.contrib.tensorrt.simple_bind(softmax, all_params, ctx=ctx, + data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) + else: + softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) finally: mx.contrib.tensorrt.set_use_tensorrt(original_trt_value) From 7fff80cc88db60ed2a87f4859ef58f3ca1be533a Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 10 Aug 2018 02:40:59 +0200 Subject: [PATCH 41/43] Rename trt bind call --- python/mxnet/contrib/tensorrt.py | 2 +- tests/python/tensorrt/test_cvnets.py | 8 ++++---- tests/python/tensorrt/test_tensorrt_lenet5.py | 10 +++++----- tests/python/tensorrt/test_training_warning.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index 0124a8966737..a9be7780b3c5 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -73,7 +73,7 @@ def get_optimized_symbol(executor): raise -def simple_bind(symbol, all_params, ctx, **kwargs): +def trt_bind(symbol, all_params, ctx, **kwargs): """Bind current symbol to get an optimized trt executor. Parameters diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py index 433840dbb151..de656b9248a6 100644 --- a/tests/python/tensorrt/test_cvnets.py +++ b/tests/python/tensorrt/test_cvnets.py @@ -36,10 +36,10 @@ def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): out = net(data) softmax = mx.sym.SoftmaxOutput(out, name='softmax') all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - executor = mx.contrib.tensorrt.simple_bind(softmax, all_params, ctx=ctx, - data=(batch_size,3, h, w), - softmax_label=(batch_size,), grad_req='null', - force_rebind=True) + executor = mx.contrib.tensorrt.trt_bind(softmax, all_params, ctx=ctx, + data=(batch_size,3, h, w), + softmax_label=(batch_size,), grad_req='null', + force_rebind=True) else: # Convert gluon model to Symbolic net.hybridize() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index a91e25f858c9..9f6510e37c47 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -29,11 +29,11 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz data_size = (batch_size,) + mnist['test_data'].shape[1:] if use_tensorrt: all_params = merge_dicts(arg_params, aux_params) - executor = mx.contrib.tensorrt.simple_bind(sym, all_params, ctx=mx.gpu(0), - data=data_size, - softmax_label=(batch_size,), - grad_req='null', - force_rebind=True) + executor = mx.contrib.tensorrt.trt_bind(sym, all_params, ctx=mx.gpu(0), + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) else: executor = sym.simple_bind(ctx=mx.gpu(0), data=data_size, diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py index a8dd1a465093..253a5d5c4b7c 100644 --- a/tests/python/tensorrt/test_training_warning.py +++ b/tests/python/tensorrt/test_training_warning.py @@ -55,9 +55,9 @@ def run_resnet(is_train, use_tensorrt): grad_req = 'null' if use_tensorrt: all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) - mx.contrib.tensorrt.simple_bind(softmax, all_params, ctx=ctx, - data=(batch_size, 3, h, w), softmax_label=(batch_size,), - force_rebind=True, grad_req=grad_req) + mx.contrib.tensorrt.trt_bind(softmax, all_params, ctx=ctx, + data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) else: softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), force_rebind=True, grad_req=grad_req) From a754aaba484b5b660b3ee47008e9faf5318c2086 Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 10 Aug 2018 02:53:52 +0200 Subject: [PATCH 42/43] Remove documentation that is not needed for trt --- python/mxnet/contrib/tensorrt.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index a9be7780b3c5..fb8d222b3e2b 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -87,15 +87,6 @@ def trt_bind(symbol, all_params, ctx, **kwargs): ctx : Context The device context the generated executor to run on. - grad_req: string - {'write', 'add', 'null'}, or list of str or dict of str to str, optional - To specify how we should update the gradient to the `args_grad`. - - - 'write' means every time gradient is written to specified `args_grad` NDArray. - - 'add' means every time gradient is added to the specified NDArray. - - 'null' means no action is taken, the gradient may not be calculated. This is the only - mode supported by TensorRT - type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype @@ -105,20 +96,6 @@ def trt_bind(symbol, all_params, ctx, **kwargs): group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. - shared_arg_names : List of string - The argument names whose `NDArray` of shared_exec can be reused for initializing - the current executor. - - shared_exec : Executor - The executor whose arg_arrays, arg_arrays, grad_arrays, and aux_arrays can be - reused for initializing the current executor. - - shared_buffer : Dict of string to `NDArray` - The dict mapping argument names to the `NDArray` that can be reused for initializing - the current executor. This buffer will be checked for reuse if one argument name - of the current executor is not found in `shared_arg_names`. The `NDArray`s are - expected have default storage type. - kwargs : Dict of str->shape Input shape dictionary, name->shape From 22a3823bcb9ffd55ef28f6b29bdba33f812b018c Mon Sep 17 00:00:00 2001 From: Kellen Sunderland Date: Fri, 10 Aug 2018 03:24:47 +0200 Subject: [PATCH 43/43] Reorder arguments, allow position calling --- python/mxnet/contrib/tensorrt.py | 12 +++++++----- tests/python/tensorrt/test_cvnets.py | 8 ++++---- tests/python/tensorrt/test_tensorrt_lenet5.py | 10 +++++----- tests/python/tensorrt/test_training_warning.py | 6 +++--- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/python/mxnet/contrib/tensorrt.py b/python/mxnet/contrib/tensorrt.py index fb8d222b3e2b..bb20767b3e6f 100644 --- a/python/mxnet/contrib/tensorrt.py +++ b/python/mxnet/contrib/tensorrt.py @@ -73,7 +73,8 @@ def get_optimized_symbol(executor): raise -def trt_bind(symbol, all_params, ctx, **kwargs): +def tensorrt_bind(symbol, ctx, all_params, type_dict=None, stype_dict=None, group2ctx=None, + **kwargs): """Bind current symbol to get an optimized trt executor. Parameters @@ -81,12 +82,12 @@ def trt_bind(symbol, all_params, ctx, **kwargs): symbol : Symbol The symbol you wish to bind, and optimize with TensorRT. - all_params : Dict of str->ndarray - A dictionary of mappings from parameter names to parameter NDArrays. - ctx : Context The device context the generated executor to run on. + all_params : Dict of str->ndarray + A dictionary of mappings from parameter names to parameter NDArrays. + type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype @@ -105,4 +106,5 @@ def trt_bind(symbol, all_params, ctx, **kwargs): An optimized TensorRT executor. """ kwargs['shared_buffer'] = all_params - return symbol.simple_bind(ctx, **kwargs) + return symbol.simple_bind(ctx, type_dict=type_dict, stype_dict=stype_dict, + group2ctx=group2ctx, **kwargs) diff --git a/tests/python/tensorrt/test_cvnets.py b/tests/python/tensorrt/test_cvnets.py index de656b9248a6..4fdd522341bc 100644 --- a/tests/python/tensorrt/test_cvnets.py +++ b/tests/python/tensorrt/test_cvnets.py @@ -36,10 +36,10 @@ def get_classif_model(model_name, use_tensorrt, ctx=mx.gpu(0), batch_size=128): out = net(data) softmax = mx.sym.SoftmaxOutput(out, name='softmax') all_params = dict([(k, v.data()) for k, v in net.collect_params().items()]) - executor = mx.contrib.tensorrt.trt_bind(softmax, all_params, ctx=ctx, - data=(batch_size,3, h, w), - softmax_label=(batch_size,), grad_req='null', - force_rebind=True) + executor = mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params, + data=(batch_size,3, h, w), + softmax_label=(batch_size,), grad_req='null', + force_rebind=True) else: # Convert gluon model to Symbolic net.hybridize() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index 9f6510e37c47..258686428a45 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -29,11 +29,11 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz data_size = (batch_size,) + mnist['test_data'].shape[1:] if use_tensorrt: all_params = merge_dicts(arg_params, aux_params) - executor = mx.contrib.tensorrt.trt_bind(sym, all_params, ctx=mx.gpu(0), - data=data_size, - softmax_label=(batch_size,), - grad_req='null', - force_rebind=True) + executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0), all_params=all_params, + data=data_size, + softmax_label=(batch_size,), + grad_req='null', + force_rebind=True) else: executor = sym.simple_bind(ctx=mx.gpu(0), data=data_size, diff --git a/tests/python/tensorrt/test_training_warning.py b/tests/python/tensorrt/test_training_warning.py index 253a5d5c4b7c..fdac859aef6f 100644 --- a/tests/python/tensorrt/test_training_warning.py +++ b/tests/python/tensorrt/test_training_warning.py @@ -55,9 +55,9 @@ def run_resnet(is_train, use_tensorrt): grad_req = 'null' if use_tensorrt: all_params = dict([(k, v.data()) for k, v in resnet.collect_params().items()]) - mx.contrib.tensorrt.trt_bind(softmax, all_params, ctx=ctx, - data=(batch_size, 3, h, w), softmax_label=(batch_size,), - force_rebind=True, grad_req=grad_req) + mx.contrib.tensorrt.tensorrt_bind(softmax, ctx=ctx, all_params=all_params, + data=(batch_size, 3, h, w), softmax_label=(batch_size,), + force_rebind=True, grad_req=grad_req) else: softmax.simple_bind(ctx=ctx, data=(batch_size, 3, h, w), softmax_label=(batch_size,), force_rebind=True, grad_req=grad_req)