This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-703] TensorRT runtime integration #11325

Merged
merged 47 commits into from Aug 10, 2018
Changes from 7 commits

Commits
47 commits
268e90b
[MXNET-703] TensorRT runtime integration
Jun 18, 2018
4855d8a
correctly assign self._optimized_symbol in executor
Caenorst Jul 25, 2018
419b294
declare GetTrtCompatibleSubsets and ReplaceSubgraph only if MXNET_USE…
Caenorst Jul 25, 2018
8d723c9
add comments in ReplaceSubgraph
Caenorst Jul 25, 2018
ca94624
Addressing Haibin's code review points
Jul 26, 2018
75c8642
Check that shared_buffer is not empty when USE_TENSORRT is set
Jul 26, 2018
2a11466
Added check that TensorRT binding is for inference only
Jul 26, 2018
190c9bf
Removed redundant decl.
mkolod Aug 2, 2018
d88ad8b
WIP Refactored TRT integration and tests
KellenSunderland Aug 8, 2018
87ebdce
Add more build guards, remove unused code
KellenSunderland Aug 8, 2018
83fa475
Remove ccache report
KellenSunderland Aug 8, 2018
4a3772f
Remove redundant const in declaration
KellenSunderland Aug 8, 2018
f779537
Clean Cmake TRT files
KellenSunderland Aug 8, 2018
b0748ef
Remove TensorRT env var usage
KellenSunderland Aug 8, 2018
35e1367
Use contrib optimize_graph instead of bind
KellenSunderland Aug 8, 2018
6338e45
Clean up cycle detector
KellenSunderland Aug 8, 2018
21d0239
Convert lenet test to contrib optimize
KellenSunderland Aug 8, 2018
f30dbef
Protect interface with trt build flag
KellenSunderland Aug 8, 2018
eaba593
Fix whitespace issues
KellenSunderland Aug 8, 2018
f870a3f
Add another build guard to c_api
KellenSunderland Aug 8, 2018
e40d6b3
Move get_optimized_symbol to contrib area
KellenSunderland Aug 8, 2018
7fa6a4a
Ignore gz files in test folder
KellenSunderland Aug 9, 2018
e777ab5
Make trt optimization implicit
KellenSunderland Aug 9, 2018
3ea9b89
Remove unused declaration
KellenSunderland Aug 9, 2018
d6d2cac
Replace build guards with runtime errors
KellenSunderland Aug 9, 2018
2d04aee
Change default value of TensorRT to off
KellenSunderland Aug 9, 2018
449a195
Warn user when TRT not active at runtime
KellenSunderland Aug 9, 2018
ed36739
Move TensorRTBind declaration, add descriptive errors
KellenSunderland Aug 9, 2018
882d8e5
Test TensorRT graph execution, fix bugs
KellenSunderland Aug 9, 2018
95a7955
Fix lint and whitespace issues
KellenSunderland Aug 9, 2018
0307467
Fix typo
KellenSunderland Aug 9, 2018
8504319
Removed default value for set_use_tensorrt
KellenSunderland Aug 9, 2018
55dd422
Improved documentation and fixed spacing issues
KellenSunderland Aug 9, 2018
f7ff036
Merge pull request #8 from KellenSunderland/tensorrt_integration_wip
mkolod Aug 9, 2018
ec9d3ea
Move static exec funcs to util files
KellenSunderland Aug 9, 2018
4b63738
Update comments to match util style
KellenSunderland Aug 9, 2018
694cbfb
Apply const to loop element
KellenSunderland Aug 9, 2018
2be7d25
Fix a few namespace issues
KellenSunderland Aug 9, 2018
369a3f7
Make static funcs inline to avoid compiler warning
KellenSunderland Aug 9, 2018
64b7e95
Merge pull request #9 from KellenSunderland/tensorrt_integration_15
mkolod Aug 9, 2018
1c7698b
Remove unused inference code from lenet5_train
KellenSunderland Aug 10, 2018
74b6603
Add explicit trt contrib bind, update tests to use it
KellenSunderland Aug 10, 2018
7fff80c
Rename trt bind call
KellenSunderland Aug 10, 2018
a754aab
Remove documentation that is not needed for trt
KellenSunderland Aug 10, 2018
7ea6ef9
Merge pull request #10 from KellenSunderland/tensorrt_integration_16
mkolod Aug 10, 2018
22a3823
Reorder arguments, allow position calling
KellenSunderland Aug 10, 2018
c3ace78
Merge pull request #11 from KellenSunderland/tensorrt_integration_16
mkolod Aug 10, 2018
3 changes: 3 additions & 0 deletions .gitmodules
@@ -26,3 +26,6 @@
[submodule "3rdparty/tvm"]
path = 3rdparty/tvm
url = https://github.com/dmlc/tvm
[submodule "3rdparty/onnx-tensorrt"]
path = 3rdparty/onnx-tensorrt
url = https://github.com/onnx/onnx-tensorrt.git
1 change: 1 addition & 0 deletions 3rdparty/onnx-tensorrt
Submodule onnx-tensorrt added at e7be19
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -37,6 +37,7 @@ mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support"
mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON)
mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF)
mxnet_option(USE_TENSORRT "Enable inference optimization with TensorRT." OFF)

message(STATUS "CMAKE_SYSTEM_NAME ${CMAKE_SYSTEM_NAME}")
if(USE_CUDA AND NOT USE_OLDCMAKECUDA)
@@ -185,6 +186,15 @@ if(USE_VTUNE)
list(APPEND mxnet_LINKER_LIBS dl)
endif()

if(USE_TENSORRT)
message(STATUS "Using TensorRT")
include_directories(3rdparty/onnx-tensorrt/third_party/onnx/build/)
include_directories(3rdparty/onnx-tensorrt/)
include_directories(3rdparty/)
add_definitions(-DMXNET_USE_TENSORRT=1)
add_definitions(-DONNX_NAMESPACE=onnx)
endif()

if(USE_MKLDNN)
include(cmake/MklDnn.cmake)
# CPU architecture (e.g., C5) can't run on another architecture (e.g., g3).
28 changes: 28 additions & 0 deletions Jenkinsfile
@@ -28,6 +28,7 @@ mx_dist_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3r
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
// timeout in minutes
max_time = 120
// assign any caught errors here
@@ -372,6 +373,17 @@ try {
}
}
},
'TensorRT': {
node('mxnetlinux-cpu') {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
init_git()
docker_run('ubuntu_gpu_tensorrt', 'build_ubuntu_gpu_tensorrt', false)
pack_lib('tensorrt', mx_tensorrt_lib)
}
}
}
},
'Build CPU windows':{
node('mxnetwindows-cpu') {
timeout(time: max_time, unit: 'MINUTES') {
@@ -740,6 +752,22 @@
}
}
},
'Python3: TensorRT GPU': {
node('mxnetlinux-gpu-p3') {
ws('workspace/build-tensorrt') {
timeout(time: max_time, unit: 'MINUTES') {
try {
init_git()
unpack_lib('tensorrt', mx_tensorrt_lib)
docker_run('ubuntu_gpu_tensorrt', 'unittest_ubuntu_tensorrt_gpu', true)
publish_test_coverage()
} finally {
collect_test_results_unix('nosetests_tensorrt.xml', 'nosetests_python3_tensorrt_gpu.xml')
}
}
}
}
},
'Scala: CPU': {
node('mxnetlinux-cpu') {
ws('workspace/ut-scala-cpu') {
8 changes: 8 additions & 0 deletions Makefile
@@ -91,6 +91,14 @@ else
endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)


ifeq ($(USE_TENSORRT), 1)
Contributor:
We spoke offline about this, but just a quick note that we should also add the ability to build MXNet-TensorRT integration to our cmake builds.

Contributor Author:
@KellenSunderland I agree. Should the CMake build be part of the initial PR or a subsequent one?

Contributor:
I think either way would work.

CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
endif
# -L/usr/local/lib

ifeq ($(DEBUG), 1)
NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
else
14 changes: 7 additions & 7 deletions amalgamation/amalgamation.py
@@ -23,13 +23,12 @@
import platform

blacklist = [
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
'nvml.h', 'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h',
'omp.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
]
@@ -150,6 +149,7 @@ def expand(x, pending, stage):
h not in sysheaders and
'mkl' not in h and
'nnpack' not in h and
'tensorrt' not in h and
not h.endswith('.cuh')): sysheaders.append(h)
else:
expand.treeDepth += 1
41 changes: 41 additions & 0 deletions ci/docker/Dockerfile.build.ubuntu_gpu_tensorrt
@@ -0,0 +1,41 @@
# -*- mode: dockerfile -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Dockerfile to build and run MXNet with TensorRT on Ubuntu 16.04 for GPU

FROM nvidia/cuda:9.0-cudnn7-devel

WORKDIR /work/deps

COPY install/ubuntu_core.sh /work/
RUN /work/ubuntu_core.sh
COPY install/deb_ubuntu_ccache.sh /work/
RUN /work/deb_ubuntu_ccache.sh
COPY install/ubuntu_python.sh /work/
RUN /work/ubuntu_python.sh
COPY install/tensorrt.sh /work
RUN /work/tensorrt.sh

ARG USER_ID=0
COPY install/ubuntu_adduser.sh /work/
RUN /work/ubuntu_adduser.sh

COPY runtime_functions.sh /work/

WORKDIR /work/mxnet
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
45 changes: 45 additions & 0 deletions ci/docker/install/tensorrt.sh
@@ -0,0 +1,45 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Install gluoncv since we're testing Gluon models as well
pip2 install gluoncv==0.2.0
pip3 install gluoncv==0.2.0

# Install Protobuf
# Install protoc 3.5 and build protobuf here (for onnx and onnx-tensorrt)
pushd .
cd ..
apt-get update
apt-get install -y automake libtool
git clone --recursive -b 3.5.1.1 https://github.com/google/protobuf.git
cd protobuf
./autogen.sh
./configure
make -j$(nproc)
make install
ldconfig
popd

# Install TensorRT
echo "TensorRT build enabled. Installing TensorRT."
wget -qO tensorrt.deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64/nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0_1-1_amd64.deb
dpkg -i tensorrt.deb
apt-get update
apt-get install -y --allow-downgrades libnvinfer-dev
rm tensorrt.deb
67 changes: 67 additions & 0 deletions ci/docker/runtime_functions.sh
@@ -436,6 +436,62 @@ build_ubuntu_gpu() {
build_ubuntu_gpu_cuda91_cudnn7
}

build_ubuntu_gpu_tensorrt() {

set -ex

build_ccache_wrappers

# Build ONNX
pushd .
echo "Installing ONNX."
cd 3rdparty/onnx-tensorrt/third_party/onnx
rm -rf build
mkdir -p build
cd build
cmake \
    -DCMAKE_CXX_FLAGS=-I/usr/include/python${PYVER} \
    -DBUILD_SHARED_LIBS=ON \
    -G Ninja ..
ninja -v
export LIBRARY_PATH=`pwd`:`pwd`/onnx/:$LIBRARY_PATH
export CPLUS_INCLUDE_PATH=`pwd`:$CPLUS_INCLUDE_PATH
popd

# Build ONNX-TensorRT
pushd .
cd 3rdparty/onnx-tensorrt/
mkdir -p build
cd build
cmake ..
make -j$(nproc)
export LIBRARY_PATH=`pwd`:$LIBRARY_PATH
popd

mkdir -p /work/mxnet/lib/
cp 3rdparty/onnx-tensorrt/third_party/onnx/build/*.so /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser_runtime.so.0 /work/mxnet/lib/
cp -L 3rdparty/onnx-tensorrt/build/libnvonnxparser.so.0 /work/mxnet/lib/

rm -rf build
make \
DEV=1 \
USE_BLAS=openblas \
USE_CUDA=1 \
USE_CUDA_PATH=/usr/local/cuda \
USE_CUDNN=1 \
USE_OPENCV=0 \
USE_DIST_KVSTORE=0 \
USE_TENSORRT=1 \
USE_JEMALLOC=0 \
USE_GPERFTOOLS=0 \
ONNX_NAMESPACE=onnx \
CUDA_ARCH="-gencode arch=compute_70,code=compute_70" \
-j$(nproc)

report_ccache_usage
}

build_ubuntu_gpu_mkldnn() {
set -ex

@@ -638,6 +694,15 @@ unittest_ubuntu_python3_gpu_nocudnn() {
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
}

unittest_ubuntu_tensorrt_gpu() {
set -ex
export PYTHONPATH=./python/
export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
python tests/python/tensorrt/lenet5_train.py
nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
}
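
The tests this runner drives use the explicit contrib bind that later commits in this PR introduce ("Add explicit trt contrib bind, update tests to use it"). A minimal Python sketch of that flow, assuming the mx.contrib.tensorrt.tensorrt_bind API from those commits and a LeNet-5 checkpoint written by lenet5_train.py; the checkpoint prefix, epoch, and shapes here are illustrative:

    import mxnet as mx

    batch_size = 128
    # Load the symbol and weights produced by lenet5_train.py (names assumed).
    sym, arg_params, aux_params = mx.model.load_checkpoint('lenet5', 10)
    all_params = {k: v.as_in_context(mx.gpu(0)) for k, v in arg_params.items()}
    all_params.update({k: v.as_in_context(mx.gpu(0)) for k, v in aux_params.items()})
    # Bind with the TensorRT graph pass applied; TensorRT is inference-only,
    # so gradients are disabled.
    executor = mx.contrib.tensorrt.tensorrt_bind(sym, ctx=mx.gpu(0),
                                                 all_params=all_params,
                                                 data=(batch_size, 1, 28, 28),
                                                 softmax_label=(batch_size,),
                                                 grad_req='null',
                                                 force_rebind=True)
    data = mx.nd.zeros((batch_size, 1, 28, 28), ctx=mx.gpu(0))
    out = executor.forward(is_train=False, data=data)[0]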

# quantization gpu currently only runs on P3 instances
# need to separate it from unittest_ubuntu_python2_gpu()
unittest_ubuntu_python2_quantization_gpu() {
@@ -970,3 +1035,5 @@
declare -F | cut -d' ' -f3
echo
fi


7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -1714,6 +1714,13 @@ MXNET_DLL int MXExecutorReshape(int partial_shaping,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec,
ExecutorHandle *out);

/*!
* \brief get optimized graph from graph executor
*/
MXNET_DLL int MXExecutorGetOptimizedSymbol(ExecutorHandle handle,
Contributor:
why do we need to expose this?

Contributor Author:
@piiswrong See my reply to @eric-haibin-lin here.

Contributor Author:
@piiswrong What is your take now that you have the context?

Contributor:
I think it's better to expose it as a private member of executor

SymbolHandle *out);

/*!
* \brief set a call back to notify the completion of operation
*/
16 changes: 8 additions & 8 deletions include/mxnet/executor.h
@@ -152,14 +152,14 @@ class Executor {
static Executor* SimpleBind(nnvm::Symbol symbol,
Contributor:
This breaks API backward compatibility. You can define another simple bind function such as "SimpleBindEx" to achieve your purpose.

Contributor:
@reminisce: I'm not sure that I'd consider this a breaking API change. My understanding from what was discussed on the dev list was that we would follow the hour-glass model for the c_api, and make the c_api the only native interface that we semantically version.

Contributor:
@KellenSunderland You've got a valid point. My worry was that since this function is declared in a header placed in the include directory, we have no idea whether users have used this function to build something of their own.

Contributor:
@reminisce Yes, I do see what you mean. It's certainly possible someone used the header. As long as it's not too much work, would you mind creating a second function, @mkolod?

Contributor:
I think we don't need to maintain compatibility on this. It's unlikely any user would depend on it.

Contributor Author:
Thanks @piiswrong!

Contributor:
@mkolod As we discussed offline about not breaking existing APIs, could you create a simple bind API for you to use only, rather than modifying this one?

Contributor:
why change reference to pointer?

Contributor Author (mkolod, Jul 25, 2018):
@piiswrong I think we already discussed this here. The TensorRT graph pass does a graph rewrite, and that requires mutable types. Going from const ref to ref causes the linter to fail. The linter suggestion is to switch from const refs to pointers if mutability is required. You mentioned here that this was acceptable from an API stability point of view, because no changes in the C API are necessary, only in the C++ one.

Contributor Author (mkolod, Jul 25, 2018):
@piiswrong The method that requires mutability is GraphExecutor::ReinitGraph, which can be found here.

The call path here is SimpleBind() -> Init() -> ReinitGraph() (see here).

This doesn't introduce any breaking changes to the existing code base, as long as the C API is used to handle frontend bindings.

Contributor Author (mkolod, Jul 25, 2018):
@piiswrong Also, I presume the lines in question are 155-162, not 152. Those are the ones changed by this commit; lines 152-154 and 163-168 are as authored on 2017-06-02 by @reminisce.

Contributor Author:
@piiswrong Thoughts about the above?

Contributor:
why not pass by value?

Contributor:
Also, I think it's better to name the function InitTensorRT rather than ReinitGraph.

const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<Context>* in_arg_ctxes,
Contributor:
Question: what is the purpose of modifying these arguments by passing pointers?

Contributor Author:
@reminisce Because if things are to be mutated, they need to be pointers, not non-const references (per the linter rules). Given your earlier comments about SimpleBindEx rather than modifying SimpleBind, this will be addressed there rather than modifying it here.

std::vector<Context>* arg_grad_ctxes,
std::vector<Context>* aux_state_ctxes,
std::unordered_map<std::string, TShape>* arg_shape_map,
std::unordered_map<std::string, int>* arg_dtype_map,
std::unordered_map<std::string, int>* arg_stype_map,
std::vector<OpReqType>* grad_req_types,
std::unordered_set<std::string>* param_names,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grads,
std::vector<NDArray>* aux_states,
16 changes: 16 additions & 0 deletions python/mxnet/base.py
@@ -709,3 +709,19 @@ def write_all_str(module_file, module_all_list):
module_op_file.close()
write_all_str(module_internal_file, module_internal_all)
module_internal_file.close()

def cint(init_val=0):
"""create a C int with an optional initial value"""
return C.c_int(init_val)

def int_addr(x):
"""given a c_int, return it's address as an int ptr"""
x_addr = C.addressof(x)
int_p = C.POINTER(C.c_int)
x_int_addr = C.cast(x_addr, int_p)
return x_int_addr

def checked_call(f, *args):
"""call a cuda function and check for success"""
error_t = f(*args)
assert error_t == 0, "Failing cuda call %s returns %s." % (f.__name__, error_t)
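
The three helpers above are thin ctypes conveniences for calling C functions that write a result through an int pointer and signal failure with a nonzero return code. A minimal sketch of how they compose, using a CUDA runtime call purely for illustration (loading libcudart by name is an assumption, not part of this diff):

    import ctypes as C
    from mxnet.base import cint, int_addr, checked_call

    cudart = C.CDLL('libcudart.so')  # assumed library name, for illustration only
    count = cint()                   # fresh c_int, initialized to 0
    # cudaGetDeviceCount(int*) writes the device count and returns 0 on
    # success, which checked_call asserts.
    checked_call(cudart.cudaGetDeviceCount, int_addr(count))
    print('visible CUDA devices:', count.value)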
19 changes: 18 additions & 1 deletion python/mxnet/executor.py
@@ -24,8 +24,9 @@
import ctypes
import copy
import numpy as np
import mxnet as mx
from .base import _LIB
from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
from .base import mx_uint, NDArrayHandle, ExecutorHandle, SymbolHandle, py_str
from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
@@ -73,6 +74,7 @@ def __init__(self, handle, symbol, ctx, grad_req, group2ctx):
self.aux_arrays = []
self.outputs = self._get_outputs()
self._symbol = copy.deepcopy(symbol)
self._optimized_symbol = None
self._arg_dict = None
self._grad_dict = None
self._aux_dict = None
@@ -323,6 +325,21 @@ def output_dict(self):
self._symbol.list_outputs(), self.outputs)
return self._output_dict

@property
def optimized_symbol(self):
Member:
When would a user want to access optimized_symbol? Is this added only for testing purposes? Does it make sense to keep it private?

Contributor Author (mkolod, Jul 25, 2018):
@eric-haibin-lin For now this is added for testing only, but it's actually quite useful, and not just for the TensorRT integration, but potentially for other acceleration libraries in the future. Imagine wanting to visualize a model to figure out what the acceleration library did. Looking at the timeline may be helpful, but only if one knows which kernels belong to the accelerator, and which to MXNet or direct cuBLAS/cuDNN calls. Seeing how the graph got rewritten can help determine which layers could not be handled by the acceleration library, and therefore what asks to make of libraries going forward. For example, TensorRT doesn't handle non-maximum suppression (NMS) directly; it can do that via a plugin. This initial integration doesn't handle TensorRT plugins. Knowing what didn't get subsumed into the accelerator graph node and comparing against a profile can determine whether, for instance, layers not handled by the acceleration library subgraph are dominating the profile; a visual inspection makes this really quick. This could be useful for both the Graphviz integration, as well as potentially for TensorBoard. We tested this with Graphviz already, and it's working fine. Check out the attached visualization of LeNet-5. The graph on the left is the original MXNet one, and the one on the right is with all layers except for the input tensor having been replaced by a single TensorRT node. For a more complex network such as SSD, the NMS layer would be visible as a separate NNVM node, outside of the TensorRT subgraph.

Basically, once you get the optimized symbol, you can also call

mx.viz.plot_network(optimized_symbol)

as you would for the original symbol.

[Attached image: LeNet-5 symbol graph, original (left) and with the TensorRT subgraph node (right)]

"""Get optimized symbol.

Returns
-------
symbol : nnvm::Symbol
The optimized nnvm symbol.
"""
if self._optimized_symbol is None:
handle = SymbolHandle()
check_call(_LIB.MXExecutorGetOptimizedSymbol(self.handle, ctypes.byref(handle)))
Contributor:
what happens if _optimized_symbol exists?

Contributor:
That is a mistake indeed; _optimized_symbol is never modified, so it would never exist. Corrected.

self._optimized_symbol = mx.sym.Symbol(handle=handle)
return self._optimized_symbol
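
As the thread above notes, the main value of this property is making the rewritten graph inspectable. A minimal usage sketch, assuming a TensorRT-enabled build and an executor already bound with the TensorRT pass applied (the executor variable is a placeholder):

    import mxnet as mx

    trt_sym = executor.optimized_symbol
    # Layers fused by TensorRT show up as a single subgraph node.
    print(trt_sym.tojson())
    mx.viz.plot_network(trt_sym)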

def copy_params_from(self, arg_params, aux_params=None, allow_extra_params=False):
"""Copy parameters from arg_params, aux_params into executor's internal array.
