diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6f8c33b6a23d..896c7b75a1ec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -290,6 +290,16 @@ if(USE_CUDA)
message(WARNING "Could not find NCCL libraries")
endif()
endif()
+ if(UNIX)
+ find_package(NVTX)
+ if(NVTX_FOUND)
+ include_directories(${NVTX_INCLUDE_DIRS})
+ list(APPEND mxnet_LINKER_LIBS ${NVTX_LIBRARIES})
+ add_definitions(-DMXNET_USE_NVTX=1)
+ else()
+ message(WARNING "Could not find NVTX libraries")
+ endif()
+ endif()
else()
add_definitions(-DMSHADOW_USE_CUDA=0)
endif()
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a847a9ec4462..f8873ce2752a 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -239,6 +239,7 @@ List of Contributors
* [Zak Jost](https://github.com/zjost)
* [Shoubhik Bhattacharya](https://github.com/shoubhik)
* [Zach Kimberg](https://github.com/zachgk)
+* [Rohit Srivastava](https://github.com/access2rohit)
* [Caner Turkmen](https://github.com/canerturkmen)
Label Bot
diff --git a/Makefile b/Makefile
index df0fe8809456..b578683a74b6 100644
--- a/Makefile
+++ b/Makefile
@@ -106,6 +106,11 @@ ifeq ($(ENABLE_TESTCOVERAGE), 1)
LDFLAGS += --coverage
endif
+ifeq ($(USE_NVTX), 1)
+ CFLAGS += -DMXNET_USE_NVTX=1
+ LDFLAGS += -lnvToolsExt
+endif
+
ifeq ($(USE_TENSORRT), 1)
CFLAGS += -I$(ROOTDIR) -I$(TPARTYDIR) -DONNX_NAMESPACE=$(ONNX_NAMESPACE) -DMXNET_USE_TENSORRT=1
LDFLAGS += -lprotobuf -pthread -lonnx -lonnx_proto -lnvonnxparser -lnvonnxparser_runtime -lnvinfer -lnvinfer_plugin
diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index e47ab6b0e22e..fef54aaecba4 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -30,7 +30,7 @@
'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
- 'relacy_shims.h', 'ittnotify.h', 'shared_mutex'
+ 'relacy_shims.h', 'ittnotify.h', 'shared_mutex', 'nvToolsExt.h'
]
minimum = int(sys.argv[6]) if len(sys.argv) > 5 else 0
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu b/ci/docker/Dockerfile.build.ubuntu_gpu_cu100
similarity index 100%
rename from ci/docker/Dockerfile.build.ubuntu_gpu
rename to ci/docker/Dockerfile.build.ubuntu_gpu_cu100
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu80 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu80
new file mode 100644
index 000000000000..9c7a8084b093
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu80
@@ -0,0 +1,78 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU
+
+FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+
+COPY install/ubuntu_scala.sh /work/
+COPY install/sbt.gpg /work/
+RUN /work/ubuntu_scala.sh
+
+COPY install/ubuntu_r.sh /work/
+COPY install/r.gpg /work/
+RUN /work/ubuntu_r.sh
+
+COPY install/ubuntu_perl.sh /work/
+RUN /work/ubuntu_perl.sh
+
+COPY install/ubuntu_clang.sh /work/
+RUN /work/ubuntu_clang.sh
+
+COPY install/ubuntu_mklml.sh /work/
+RUN /work/ubuntu_mklml.sh
+
+COPY install/ubuntu_tvm.sh /work/
+RUN /work/ubuntu_tvm.sh
+
+COPY install/ubuntu_llvm.sh /work/
+RUN /work/ubuntu_llvm.sh
+
+COPY install/ubuntu_caffe.sh /work/
+RUN /work/ubuntu_caffe.sh
+
+COPY install/ubuntu_onnx.sh /work/
+RUN /work/ubuntu_onnx.sh
+
+COPY install/ubuntu_docs.sh /work/
+COPY install/docs_requirements /work/
+RUN /work/ubuntu_docs.sh
+
+COPY install/ubuntu_tutorials.sh /work/
+RUN /work/ubuntu_tutorials.sh
+
+ARG USER_ID=0
+ARG GROUP_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu90 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu90
new file mode 100644
index 000000000000..f1e6570f03b9
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu90
@@ -0,0 +1,83 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU
+
+FROM nvidia/cuda:9.0-devel-ubuntu16.04
+
+ENV CUDNN_VERSION=7.3.1.20
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+
+COPY install/ubuntu_scala.sh /work/
+COPY install/sbt.gpg /work/
+RUN /work/ubuntu_scala.sh
+
+COPY install/ubuntu_r.sh /work/
+COPY install/r.gpg /work/
+RUN /work/ubuntu_r.sh
+
+COPY install/ubuntu_perl.sh /work/
+RUN /work/ubuntu_perl.sh
+
+COPY install/ubuntu_clang.sh /work/
+RUN /work/ubuntu_clang.sh
+
+COPY install/ubuntu_mklml.sh /work/
+RUN /work/ubuntu_mklml.sh
+
+COPY install/ubuntu_tvm.sh /work/
+RUN /work/ubuntu_tvm.sh
+
+COPY install/ubuntu_llvm.sh /work/
+RUN /work/ubuntu_llvm.sh
+
+COPY install/ubuntu_caffe.sh /work/
+RUN /work/ubuntu_caffe.sh
+
+COPY install/ubuntu_onnx.sh /work/
+RUN /work/ubuntu_onnx.sh
+
+COPY install/ubuntu_docs.sh /work/
+COPY install/docs_requirements /work/
+RUN /work/ubuntu_docs.sh
+
+COPY install/ubuntu_tutorials.sh /work/
+RUN /work/ubuntu_tutorials.sh
+
+ARG USER_ID=0
+ARG GROUP_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY install/ubuntu_cudnn.sh /work/
+RUN /work/ubuntu_cudnn.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu92 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu92
new file mode 100644
index 000000000000..81b337e4d9a7
--- /dev/null
+++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu92
@@ -0,0 +1,83 @@
+# -*- mode: dockerfile -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# Dockerfile to run MXNet on Ubuntu 16.04 for GPU
+
+FROM nvidia/cuda:9.2-devel-ubuntu16.04
+
+ENV CUDNN_VERSION=7.3.1.20
+
+WORKDIR /work/deps
+
+COPY install/ubuntu_core.sh /work/
+RUN /work/ubuntu_core.sh
+
+COPY install/deb_ubuntu_ccache.sh /work/
+RUN /work/deb_ubuntu_ccache.sh
+
+COPY install/ubuntu_python.sh /work/
+RUN /work/ubuntu_python.sh
+
+COPY install/ubuntu_scala.sh /work/
+COPY install/sbt.gpg /work/
+RUN /work/ubuntu_scala.sh
+
+COPY install/ubuntu_r.sh /work/
+COPY install/r.gpg /work/
+RUN /work/ubuntu_r.sh
+
+COPY install/ubuntu_perl.sh /work/
+RUN /work/ubuntu_perl.sh
+
+COPY install/ubuntu_clang.sh /work/
+RUN /work/ubuntu_clang.sh
+
+COPY install/ubuntu_mklml.sh /work/
+RUN /work/ubuntu_mklml.sh
+
+COPY install/ubuntu_tvm.sh /work/
+RUN /work/ubuntu_tvm.sh
+
+COPY install/ubuntu_llvm.sh /work/
+RUN /work/ubuntu_llvm.sh
+
+COPY install/ubuntu_caffe.sh /work/
+RUN /work/ubuntu_caffe.sh
+
+COPY install/ubuntu_onnx.sh /work/
+RUN /work/ubuntu_onnx.sh
+
+COPY install/ubuntu_docs.sh /work/
+COPY install/docs_requirements /work/
+RUN /work/ubuntu_docs.sh
+
+COPY install/ubuntu_tutorials.sh /work/
+RUN /work/ubuntu_tutorials.sh
+
+ARG USER_ID=0
+ARG GROUP_ID=0
+COPY install/ubuntu_adduser.sh /work/
+RUN /work/ubuntu_adduser.sh
+
+COPY install/ubuntu_cudnn.sh /work/
+RUN /work/ubuntu_cudnn.sh
+
+COPY runtime_functions.sh /work/
+
+WORKDIR /work/mxnet
+ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib
diff --git a/ci/docker/install/ubuntu_cudnn.sh b/ci/docker/install/ubuntu_cudnn.sh
index 12b64865a219..3f649074f4d5 100755
--- a/ci/docker/install/ubuntu_cudnn.sh
+++ b/ci/docker/install/ubuntu_cudnn.sh
@@ -24,6 +24,31 @@
set -ex
+if [ -z ${CUDNN_VERSION} ]; then
+ echo "Error: CUDNN_VERSION environment variable undefiend"
+ exit 1
+fi
+
apt-get update || true
-apt-get install -y libcudnn7=7.3.1.20-1+cuda10.0 libcudnn7-dev=7.3.1.20-1+cuda10.0
+
+case ${CUDA_VERSION} in
+ 10\.0*)
+ export libcudnn7_version="${CUDNN_VERSION}-1+cuda10.0"
+ export libcudnn7_dev_version="${CUDNN_VERSION}-1+cuda10.0"
+ ;;
+ 9\.0*)
+ export libcudnn7_version="${CUDNN_VERSION}-1+cuda9.0"
+ export libcudnn7_dev_version="${CUDNN_VERSION}-1+cuda9.0"
+ ;;
+ 9\.2*)
+ export libcudnn7_version="${CUDNN_VERSION}-1+cuda9.2"
+ export libcudnn7_dev_version="${CUDNN_VERSION}-1+cuda9.2"
+ ;;
+ *)
+ echo "Unsupported CUDA version ${CUDA_VERSION}"
+ exit 1
+ ;;
+esac
+
+apt-get install -y libcudnn7=${libcudnn7_version} libcudnn7-dev=${libcudnn7_dev_version}
diff --git a/ci/docker/install/ubuntu_tutorials.sh b/ci/docker/install/ubuntu_tutorials.sh
index 60adf46e6d8a..4e40426ed85c 100755
--- a/ci/docker/install/ubuntu_tutorials.sh
+++ b/ci/docker/install/ubuntu_tutorials.sh
@@ -23,5 +23,7 @@
set -ex
apt-get update || true
apt-get install graphviz python-opencv
-pip2 install jupyter matplotlib Pillow opencv-python scikit-learn graphviz tqdm mxboard scipy
+
+# sckit-learn past version 0.20 does not support python version 2 and 3.4
+pip2 install jupyter matplotlib Pillow opencv-python "scikit-learn<0.21.0" graphviz tqdm mxboard scipy
pip3 install jupyter matplotlib Pillow opencv-python scikit-learn graphviz tqdm mxboard scipy
diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy
index 1352e92c785f..2c86cf740a80 100644
--- a/ci/jenkins/Jenkins_steps.groovy
+++ b/ci/jenkins/Jenkins_steps.groovy
@@ -139,7 +139,7 @@ def compile_unix_int64_gpu() {
ws('workspace/build-gpu-int64') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
- utils.docker_run('ubuntu_gpu', 'build_ubuntu_gpu_large_tensor', false)
+ utils.docker_run('ubuntu_gpu_cu100', 'build_ubuntu_gpu_large_tensor', false)
utils.pack_lib('ubuntu_gpu_int64', mx_cmake_lib, true)
}
}
@@ -237,7 +237,7 @@ def compile_unix_cmake_mkldnn_gpu() {
ws('workspace/build-cmake-mkldnn-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
- utils.docker_run('ubuntu_gpu', 'build_ubuntu_gpu_cmake_mkldnn', false)
+ utils.docker_run('ubuntu_gpu_cu100', 'build_ubuntu_gpu_cmake_mkldnn', false)
utils.pack_lib('cmake_mkldnn_gpu', mx_cmake_mkldnn_lib, true)
}
}
@@ -251,7 +251,7 @@ def compile_unix_cmake_gpu() {
ws('workspace/build-cmake-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
- utils.docker_run('ubuntu_gpu', 'build_ubuntu_gpu_cmake', false)
+ utils.docker_run('ubuntu_gpu_cu100', 'build_ubuntu_gpu_cmake', false)
utils.pack_lib('cmake_gpu', mx_cmake_lib, true)
}
}
@@ -606,7 +606,7 @@ def test_unix_python2_gpu() {
ws('workspace/ut-python2-gpu') {
try {
utils.unpack_and_init('gpu', mx_lib, true)
- python2_gpu_ut('ubuntu_gpu')
+ python2_gpu_ut('ubuntu_gpu_cu100')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python2_gpu.xml')
@@ -623,7 +623,7 @@ def test_unix_python2_quantize_gpu() {
timeout(time: max_time, unit: 'MINUTES') {
try {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_ubuntu_python2_quantization_gpu', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_ubuntu_python2_quantization_gpu', true)
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_quantization_gpu.xml', 'nosetests_python2_quantize_gpu.xml')
@@ -640,7 +640,7 @@ def test_unix_python2_mkldnn_gpu() {
ws('workspace/ut-python2-mkldnn-gpu') {
try {
utils.unpack_and_init('mkldnn_gpu', mx_mkldnn_lib, true)
- python2_gpu_ut('ubuntu_gpu')
+ python2_gpu_ut('ubuntu_gpu_cu100')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python2_mkldnn_gpu.xml')
@@ -690,7 +690,7 @@ def test_unix_python3_gpu() {
ws('workspace/ut-python3-gpu') {
try {
utils.unpack_and_init('gpu', mx_lib, true)
- python3_gpu_ut('ubuntu_gpu')
+ python3_gpu_ut('ubuntu_gpu_cu100')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_gpu.xml')
@@ -707,7 +707,7 @@ def test_unix_python3_quantize_gpu() {
timeout(time: max_time, unit: 'MINUTES') {
try {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_ubuntu_python3_quantization_gpu', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_ubuntu_python3_quantization_gpu', true)
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_quantization_gpu.xml', 'nosetests_python3_quantize_gpu.xml')
@@ -792,7 +792,7 @@ def test_unix_python3_mkldnn_gpu() {
ws('workspace/ut-python3-mkldnn-gpu') {
try {
utils.unpack_and_init('mkldnn_gpu', mx_mkldnn_lib, true)
- python3_gpu_ut('ubuntu_gpu')
+ python3_gpu_ut('ubuntu_gpu_cu100')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_mkldnn_gpu.xml')
@@ -808,7 +808,7 @@ def test_unix_python3_mkldnn_nocudnn_gpu() {
ws('workspace/ut-python3-mkldnn-gpu-nocudnn') {
try {
utils.unpack_and_init('mkldnn_gpu_nocudnn', mx_mkldnn_lib, true)
- python3_gpu_ut_nocudnn('ubuntu_gpu')
+ python3_gpu_ut_nocudnn('ubuntu_gpu_cu100')
utils.publish_test_coverage()
} finally {
utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_mkldnn_gpu_nocudnn.xml')
@@ -842,7 +842,7 @@ def test_unix_python3_integration_gpu() {
ws('workspace/it-python-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'integrationtest_ubuntu_gpu_python', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'integrationtest_ubuntu_gpu_python', true)
utils.publish_test_coverage()
}
}
@@ -857,7 +857,7 @@ def test_unix_caffe_gpu() {
timeout(time: max_time, unit: 'MINUTES') {
utils.init_git()
utils.unpack_lib('gpu', mx_lib)
- utils.docker_run('ubuntu_gpu', 'integrationtest_ubuntu_gpu_caffe', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'integrationtest_ubuntu_gpu_caffe', true)
utils.publish_test_coverage()
}
}
@@ -871,7 +871,7 @@ def test_unix_cpp_package_gpu() {
ws('workspace/it-cpp-package') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib_cpp_examples, true)
- utils.docker_run('ubuntu_gpu', 'integrationtest_ubuntu_gpu_cpp_package', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'integrationtest_ubuntu_gpu_cpp_package', true)
utils.publish_test_coverage()
}
}
@@ -913,7 +913,7 @@ def test_unix_scala_gpu() {
ws('workspace/ut-scala-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'integrationtest_ubuntu_gpu_scala', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'integrationtest_ubuntu_gpu_scala', true)
utils.publish_test_coverage()
}
}
@@ -996,7 +996,7 @@ def test_unix_cpp_gpu() {
ws('workspace/ut-cpp-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('cmake_gpu', mx_cmake_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_cpp', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_cpp', true)
utils.publish_test_coverage()
}
}
@@ -1010,7 +1010,7 @@ def test_unix_cpp_mkldnn_gpu() {
ws('workspace/ut-cpp-mkldnn-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('cmake_mkldnn_gpu', mx_cmake_mkldnn_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_cpp', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_cpp', true)
utils.publish_test_coverage()
}
}
@@ -1038,7 +1038,7 @@ def test_unix_perl_gpu() {
ws('workspace/ut-perl-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_ubuntu_cpugpu_perl', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_ubuntu_cpugpu_perl', true)
utils.publish_test_coverage()
}
}
@@ -1052,7 +1052,7 @@ def test_unix_r_gpu() {
ws('workspace/ut-r-gpu') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'unittest_ubuntu_gpu_R', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'unittest_ubuntu_gpu_R', true)
utils.publish_test_coverage()
}
}
@@ -1120,7 +1120,7 @@ def test_unix_distributed_kvstore_gpu() {
ws('workspace/it-dist-kvstore') {
timeout(time: max_time, unit: 'MINUTES') {
utils.unpack_and_init('gpu', mx_lib, true)
- utils.docker_run('ubuntu_gpu', 'integrationtest_ubuntu_gpu_dist_kvstore', true)
+ utils.docker_run('ubuntu_gpu_cu100', 'integrationtest_ubuntu_gpu_dist_kvstore', true)
utils.publish_test_coverage()
}
}
diff --git a/cmake/Modules/FindNVTX.cmake b/cmake/Modules/FindNVTX.cmake
new file mode 100644
index 000000000000..bf05eaeb092c
--- /dev/null
+++ b/cmake/Modules/FindNVTX.cmake
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set(NVTX_ROOT_DIR "" CACHE PATH "Folder contains NVIDIA NVTX")
+
+find_path(NVTX_INCLUDE_DIRS
+ NAMES nvToolsExt.h
+ PATHS $ENV{NVTOOLSEXT_PATH} ${NVTX_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
+ PATH_SUFFIXES include
+ )
+
+find_library(NVTX_LIBRARIES
+ NAMES nvToolsExt64_1.lib nvToolsExt32_1.lib nvToolsExt
+ PATHS $ENV{NVTOOLSEXT_PATH} ${NVTX_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}
+ PATH_SUFFIXES lib lib64 lib/Win32 lib/x64
+ )
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NVTX DEFAULT_MSG NVTX_INCLUDE_DIRS NVTX_LIBRARIES)
+
+if(NVTX_FOUND)
+ message(STATUS "Found NVTX (include: ${NVTX_INCLUDE_DIRS}, library: ${NVTX_LIBRARIES})")
+ mark_as_advanced(NVTX_ROOT_DIR NVTX_INCLUDE_DIRS NVTX_LIBRARIES)
+endif()
diff --git a/cpp-package/example/charRNN.cpp b/cpp-package/example/charRNN.cpp
index ac5faa47b58c..94e9455c5941 100644
--- a/cpp-package/example/charRNN.cpp
+++ b/cpp-package/example/charRNN.cpp
@@ -164,8 +164,9 @@ Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_d
auto rnn_h_init = Symbol::Variable("LSTM_init_h");
auto rnn_c_init = Symbol::Variable("LSTM_init_c");
auto rnn_params = Symbol::Variable("LSTM_parameters"); // See explanations near RNNXavier class
- auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, num_hidden, num_lstm_layer,
- RNNMode::kLstm, false, dropout, !isTrain);
+ auto variable_sequence_length = Symbol::Variable("sequence_length");
+ auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
+ num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);
auto cls_weight = Symbol::Variable("cls_weight");
diff --git a/dev_menu.py b/dev_menu.py
index d439d8194f2a..0b5db2505671 100755
--- a/dev_menu.py
+++ b/dev_menu.py
@@ -32,10 +32,12 @@
import yaml
import shutil
+
DEFAULT_PYENV=os.environ.get('DEFAULT_PYENV','py3_venv')
DEFAULT_PYTHON=os.environ.get('DEFAULT_PYTHON','python3')
DEFAULT_CMAKE_OPTIONS=os.environ.get('DEFAULT_CMAKE_OPTIONS','cmake_options.yml')
+
class Confirm(object):
def __init__(self, cmds):
self.cmds = cmds
@@ -51,6 +53,7 @@ def __call__(self):
else:
resp = input("Please answer yes or no: ")
+
class CMake(object):
def __init__(self, cmake_options_yaml=DEFAULT_CMAKE_OPTIONS, cmake_options_yaml_default='cmake/cmake_options.yml'):
if os.path.exists(cmake_options_yaml):
@@ -93,27 +96,38 @@ def __call__(self, build_dir='build', generator='Ninja', build_cmd='ninja'):
logging.info('Now building')
check_call(shlex.split(build_cmd))
+
def create_virtualenv(venv_exe, pyexe, venv) -> None:
logging.info("Creating virtualenv in %s with python %s", venv, pyexe)
if not (venv_exe and pyexe and venv):
logging.warn("Skipping creation of virtualenv")
return
check_call([venv_exe, '-p', pyexe, venv])
- activate_this_py = os.path.join(venv, 'bin', 'activate_this.py')
- # Activate virtualenv in this interpreter
- exec(open(activate_this_py).read(), dict(__file__=activate_this_py))
- check_call(['pip', 'install', '--upgrade','--force-reinstall', '-e', 'python'])
- check_call(['pip', 'install', '-r', 'tests/requirements.txt'])
+
def create_virtualenv_default():
create_virtualenv('virtualenv', DEFAULT_PYTHON, DEFAULT_PYENV)
logging.info("You can use the virtualenv by executing 'source %s/bin/activate'", DEFAULT_PYENV)
+
+def provision_virtualenv(venv_path=DEFAULT_PYENV):
+ pip = os.path.join(venv_path, 'bin', 'pip')
+ if os.path.exists(pip):
+ # Install MXNet python bindigs
+ check_call([pip, 'install', '--upgrade', '--force-reinstall', '-e', 'python'])
+ # Install test dependencies
+ check_call([pip, 'install', '--upgrade', '--force-reinstall', '-r', os.path.join('tests',
+ 'requirements.txt')])
+ else:
+ logging.warn("Can't find pip: '%s' not found", pip)
+
+
COMMANDS = OrderedDict([
('[Local] BUILD CMake/Ninja (using cmake_options.yaml (cp cmake/cmake_options.yml .) and edit) ({} virtualenv in "{}")'.format(DEFAULT_PYTHON, DEFAULT_PYENV),
[
CMake(),
create_virtualenv_default,
+ provision_virtualenv,
]),
('[Local] Python Unit tests',
"./py3_venv/bin/nosetests -v tests/python/unittest/"
@@ -209,7 +223,8 @@ def build(args) -> None:
else:
cmake = CMake()
cmake()
- create_virtualenv(venv_exe, pyexe, args.venv)
+ create_virtualenv_default()
+ provision_virtualenv()
def main():
logging.getLogger().setLevel(logging.INFO)
diff --git a/docs/api/python/profiler/profiler.md b/docs/api/python/profiler/profiler.md
index 565495e27c02..c025811f5a5d 100644
--- a/docs/api/python/profiler/profiler.md
+++ b/docs/api/python/profiler/profiler.md
@@ -2,7 +2,7 @@
## Overview
-MXNet has a built-in profiler which is compatibule with both Intel® VTune™ Amplifier as well as Chrome's chrome://tracing visualization engine. When built witht he USE_VTUNE=1 flag, MXNet makes actual VTune API calls to define Domains, Frames, Tasks, Events Counters, and Markers. For a detailed explanation of these, see [Instrumentation and Tracing Technology API Reference ](https://software.intel.com/en-us/vtune-amplifier-help-instrumentation-and-tracing-technology-api-reference)
+MXNet has a built-in profiler which is compatible with Intel® VTune™ Amplifier, NVIDIA NVTX and Chrome's chrome://tracing visualization engine. When built with the USE_VTUNE=1 flag, MXNet makes VTune API calls to define Domains, Frames, Tasks, Events Counters, and Markers. For a detailed explanation of these, see [Instrumentation and Tracing Technology API Reference ](https://software.intel.com/en-us/vtune-amplifier-help-instrumentation-and-tracing-technology-api-reference). When built with CUDA NVTX ranges will be inserted into any profiles generated, which can subsequently be viewed view NVProf.
```eval_rst
.. autosummary::
@@ -34,7 +34,7 @@ MXNet has a built-in profiler which is compatibule with both Intel® VTune™ Am
### Profiling Objects
-These profiling objects can be created and accessed from python in order to resord performance information of the python code paths
+These profiling objects can be created and accessed from python in order to record performance information of the python code paths.
```eval_rst
.. autosummary::
diff --git a/docs/tutorials/python/profiler.md b/docs/tutorials/python/profiler.md
index fe7611aa538f..d3e3355b8f4a 100644
--- a/docs/tutorials/python/profiler.md
+++ b/docs/tutorials/python/profiler.md
@@ -185,7 +185,7 @@ MXNet executes computation graphs in 'bulk mode' which reduces kernel launch gap
### Viewing profiler output
-There are two ways to view the information collected by the profiler. You can either view it in the console or you can view a more graphical version in a browser.
+There are a few ways to view the information collected by the profiler. You can view it in the console, you can view a more graphical version in a browser, or you can use a vendor tool such as Intel VTune or Nvidia NVProf to view output. For most scenarios the information you need can be obtained with MXNet's built in profiler support, but if you want to investigate the performance of operators along side extra context about your hardware (e.g. cache hit rates, or CUDA kernel timings) then profiling jointly with vendor tools is recommended.
#### 1. View in console
@@ -215,6 +215,29 @@ Let's zoom in to check the time taken by operators
The above picture visualizes the sequence in which the operators were executed and the time taken by each operator.
+#### 3. View in NVProf
+
+You can view all MXNet profiler information alongside CUDA kernel information by using the MXNet profiler along with NVProf. Use the MXNet profiler as in the samples above, but invoke your python script with the following wrapper process available on most systems that support CUDA:
+
+```bash
+nvprof -o my_profile.nvvp python my_profiler_script.py
+==11588== NVPROF is profiling process 11588, command: python my_profiler_script.py
+==11588== Generated result file: /home/kellen/Development/incubator-mxnet/ci/my_profile.nvvp
+```
+Your my_profile.nvvp file will automatically be annotated with NVTX ranges displayed alongside your standard NVProf timeline. This can be very useful when you're trying to find patterns between operators run by MXNet, and their associated CUDA kernel calls.
+
+![Operator profiling](profiler_nvprof.png)
+
+In this picture we see a rough overlay of a few types of information plotted on a horizontal timeline. At the top of the plot we have CPU tasks such as driver operations, memory copy calls, MXNet engine operator invocations, and imperative MXNet API calls. Below we see the kernels active on the GPU during the same time period.
+
+![Operator profiling](profiler_nvprof_zoomed.png)
+
+Zooming in on a backwards convolution operator we can see that it is in fact made up of a number of different GPU kernel calls, including a cuDNN winograd convolution call, and a fast-fourier transform call.
+
+![Operator profiling](profiler_winograd.png)
+
+Selecting any of these kernel calls (the winograd convolution call shown here) will get you some interesting GPU performance information such as occupancy rates (vs theoretical), shared memory usage and execution duration.
+
### Further reading
- [Examples using MXNet profiler.](https://github.com/apache/incubator-mxnet/tree/master/example/profiler)
diff --git a/docs/tutorials/python/profiler_nvprof.png b/docs/tutorials/python/profiler_nvprof.png
new file mode 100644
index 000000000000..37d8615c2b54
Binary files /dev/null and b/docs/tutorials/python/profiler_nvprof.png differ
diff --git a/docs/tutorials/python/profiler_nvprof_zoomed.png b/docs/tutorials/python/profiler_nvprof_zoomed.png
new file mode 100644
index 000000000000..9b6b6e88e93d
Binary files /dev/null and b/docs/tutorials/python/profiler_nvprof_zoomed.png differ
diff --git a/docs/tutorials/python/profiler_winograd.png b/docs/tutorials/python/profiler_winograd.png
new file mode 100644
index 000000000000..5b4fcc3155fb
Binary files /dev/null and b/docs/tutorials/python/profiler_winograd.png differ
diff --git a/make/config.mk b/make/config.mk
index 20834675ecbd..2080a016572f 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -80,6 +80,9 @@ ENABLE_CUDA_RTC = 1
# whether use CuDNN R3 library
USE_CUDNN = 0
+# whether to use NVTX when profiling
+USE_NVTX = 0
+
#whether to use NCCL library
USE_NCCL = 0
#add the path to NCCL library
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 6dfec43a8b5f..b3cc596282a7 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
- dtype, **kwargs):
+ dtype, use_sequence_length=False, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -58,6 +58,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype
+ self._use_sequence_length = use_sequence_length
self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
@@ -219,29 +220,39 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states
- def hybrid_forward(self, F, inputs, states=None, **kwargs):
- if F is ndarray:
- batch_size = inputs.shape[self._layout.find('N')]
- skip_states = states is None
- if skip_states:
- if F is ndarray:
+ def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
+ self.skip_states = states is None
+ if states is None:
+ if isinstance(inputs, ndarray.NDArray):
+ batch_size = inputs.shape[self._layout.find('N')]
states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]
+
+ if self._use_sequence_length:
+ return super(_RNNLayer, self).__call__(inputs, states, sequence_length, **kwargs)
+ else:
+ return super(_RNNLayer, self).__call__(inputs, states, **kwargs)
+
+
+ def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
+ if F is ndarray:
+ batch_size = inputs.shape[self._layout.find('N')]
+
if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
- out = self._forward_kernel(F, inputs, states, **kwargs)
+ out = self._forward_kernel(F, inputs, states, sequence_length, **kwargs)
# out is (output, state)
- return out[0] if skip_states else out
+ return out[0] if self.skip_states else out
- def _forward_kernel(self, F, inputs, states, **kwargs):
+ def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
@@ -261,14 +272,20 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
params = F._internal._rnn_param_concat(*params, dim=0)
- rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
- projection_size=self._projection_size,
+ if self._use_sequence_length:
+ rnn_args = states + [sequence_length]
+ else:
+ rnn_args = states
+
+ rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
+ state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)
+
if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 3764f5a4a040..40888012fc5a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -741,10 +741,16 @@ object NDArray extends NDArrayBase {
*
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
- val writable: Boolean = true,
- addToCollector: Boolean = true) extends NativeResource {
- if (addToCollector) {
- NDArrayCollector.collect(this)
+ val writable: Boolean) extends NativeResource {
+
+ @deprecated("Please use ResourceScope instead", "1.5.0")
+ def this(handle: NDArrayHandle,
+ writable: Boolean = true,
+ addToCollector: Boolean = true) {
+ this(handle, writable)
+ if (addToCollector) {
+ NDArrayCollector.collect(this)
+ }
}
override def nativeAddress: CPtrAddress = handle
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
index 0b7f9af705f1..0761481cdfe8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
@@ -64,6 +64,7 @@ import scala.collection.mutable
* });
*
*/
+@deprecated("Please use ResourceScope instead", "1.5.0")
object NDArrayCollector {
private val logger = LoggerFactory.getLogger(classOf[NDArrayCollector])
@@ -75,12 +76,14 @@ object NDArrayCollector {
* Create a collector which will dispose the collected NDArrays automatically.
* @return an auto-disposable collector.
*/
+ @deprecated("Please use ResourceScope instead", "1.5.0")
def auto(): NDArrayCollector = new NDArrayCollector(true)
/**
* Create a collector allows users to later dispose the collected NDArray manually.
* @return a manually-disposable collector.
*/
+ @deprecated("Please use ResourceScope instead", "1.5.0")
@Experimental
def manual(): NDArrayCollector = new NDArrayCollector(false)
@@ -88,11 +91,13 @@ object NDArrayCollector {
* Collect the NDArrays into the collector of the current thread.
* @param ndArray NDArrays need to be collected.
*/
+ @deprecated("Please use ResourceScope instead", "1.5.0")
@varargs def collect(ndArray: NDArray*): Unit = {
currCollector.get().add(ndArray: _*)
}
}
+@deprecated("Please use ResourceScope instead", "1.5.0")
class NDArrayCollector private(private val autoDispose: Boolean = true,
private val doCollect: Boolean = true) {
// native ptr (handle) of the NDArray -> NDArray
@@ -142,6 +147,7 @@ class NDArrayCollector private(private val autoDispose: Boolean = true,
* @return The result of function codeBlock.
*/
@Experimental
+ @deprecated("Please use ResourceScope instead", "1.5.0")
def withScope[T](codeBlock: => T): T = {
val old = NDArrayCollector.currCollector.get()
NDArrayCollector.currCollector.set(this)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index b955c185b6d1..9a822870749e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -63,6 +63,14 @@ class ResourceScope extends AutoCloseable {
resource.scope = Some(this)
}
+ /**
+ * Check if a NativeResource is in the scope
+ * @param resource
+ */
+ def contains(resource: NativeResource): Boolean = {
+ resourceQ.contains(resource)
+ }
+
/**
* Remove NativeResource from the Scope, this uses
* object equality to find the resource in the stack.
@@ -80,8 +88,10 @@ class ResourceScope extends AutoCloseable {
def moveToOuterScope(resource: NativeResource): Unit = {
val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
if (prevScope.isDefined) {
- this.remove(resource)
- prevScope.get.add(resource)
+ if (contains(resource)) {
+ this.remove(resource)
+ prevScope.get.add(resource)
+ }
} else this.remove(resource)
}
@@ -109,20 +119,16 @@ object ResourceScope {
val curScope = if (scope != null) scope else new ResourceScope()
- @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
- g.foreach( n =>
- n match {
- case nRes: NativeResource => {
- curScope.moveToOuterScope(nRes)
- }
- case kv: scala.Tuple2[_, _] => {
- if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
- kv._1.asInstanceOf[NativeResource])
- if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
- kv._2.asInstanceOf[NativeResource])
- }
- }
- )
+ def recursiveMoveToOuterScope(resource: Any): Unit = {
+ resource match {
+ case nRes: NativeResource => curScope.moveToOuterScope(nRes)
+ case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
+ case resInGeneric: scala.collection.Traversable[_] =>
+ resInGeneric.foreach(recursiveMoveToOuterScope)
+ case resProduct: scala.Product =>
+ resProduct.productIterator.foreach(recursiveMoveToOuterScope)
+ case _ => // do nothing
+ }
}
@inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = {
@@ -133,13 +139,7 @@ object ResourceScope {
try {
val ret = body
- ret match {
- // don't de-allocate if returning any collection that contains NativeResource.
- case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric)
- case nRes: NativeResource => curScope.moveToOuterScope(nRes)
- case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => curScope.moveToOuterScope(nd) )
- case _ => // do nothing
- }
+ recursiveMoveToOuterScope(ret)
ret
} catch {
case t: Throwable =>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index e9513257c050..cda181823205 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -177,7 +177,7 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1, ndArray.shape.size)
// The new NDArray has to be created such that it inherits dtype from the passed in array
val newArray = NDArray.zeros(shape, dtype = ndArray.dtype)
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val batch = ndArray.slice(cursor, numData)
val padding = ndArray.slice(0, padNum)
newArray.slice(0, dataBatchSize - padNum).set(batch)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala
index dc2faec9dd91..fde9bdbc0abf 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmark.scala
@@ -50,7 +50,7 @@ object ScalaInferenceBenchmark {
List[Long] = {
var inferenceTimes: List[Long] = List()
for (i <- 1 to totalRuns) {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val startTimeSingle = System.currentTimeMillis()
objectToRun.runSingleInference(loadedModel, dataSet)
val estimatedTimeSingle = System.currentTimeMillis() - startTimeSingle
@@ -67,7 +67,7 @@ object ScalaInferenceBenchmark {
var inferenceTimes: List[Long] = List()
for (batch <- dataSetBatches) {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val loadedBatch = objecToRun.loadInputBatch(batch)
val startTimeSingle = System.currentTimeMillis()
objecToRun.runBatchInference(loadedModel, loadedBatch)
@@ -133,7 +133,7 @@ object ScalaInferenceBenchmark {
logger.info("Running single inference call")
// Benchmarking single inference call
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val loadedModel = loadModel(exampleToBenchmark, context, false)
val dataSet = loadDataSet(exampleToBenchmark)
val inferenceTimes = runInference(exampleToBenchmark, loadedModel, dataSet, baseCLI.count)
@@ -143,7 +143,7 @@ object ScalaInferenceBenchmark {
if (baseCLI.batchSize != 0) {
logger.info("Running for batch inference call")
// Benchmarking batch inference call
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val loadedModel = loadModel(exampleToBenchmark, context, true)
val batchDataSet = loadBatchDataSet(exampleToBenchmark, baseCLI.batchSize)
val inferenceTimes = runBatchInference(exampleToBenchmark, loadedModel, batchDataSet)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
index d9902e9dcc75..04cc6e240bc2 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
@@ -18,7 +18,7 @@
package org.apache.mxnetexamples.cnntextclassification
import org.apache.mxnet.optimizer.RMSProp
-import org.apache.mxnet.{Context, Executor, Model, NDArray, NDArrayCollector, Optimizer, Shape, Symbol, Uniform}
+import org.apache.mxnet.{Context, Executor, Model, NDArray, Optimizer, ResourceScope, Shape, Symbol, Uniform}
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
@@ -131,7 +131,7 @@ object CNNTextClassification {
numTotal = 0f
updateRate = 0
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
for (begin <- 0 until trainBatches.length by batchSize) {
val (batchD, batchL) = {
if (begin + batchSize <= trainBatches.length) {
@@ -239,7 +239,7 @@ object CNNTextClassification {
def test(w2vFilePath : String, mrDatasetPath: String,
ctx : Context, saveModelPath: String) : Float = {
- val output = NDArrayCollector.auto().withScope {
+ val output = ResourceScope.using() {
val (numEmbed, word2vec) = DataHelper.loadGoogleModel(w2vFilePath)
val (datas, labels) = DataHelper.loadMSDataWithWord2vec(
mrDatasetPath, numEmbed, word2vec)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
index 0cfcc49aee04..42922f212c11 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.customop
import org.apache.mxnet.Callback.Speedometer
import org.apache.mxnet.DType.DType
-import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, NDArrayCollector, Operator, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Accuracy, Context, CustomOp, CustomOpProp, NDArray, Operator, ResourceScope, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.RMSProp
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
@@ -141,7 +141,7 @@ object ExampleCustomOp {
evalMetric.reset()
var nBatch = 0
var epochDone = false
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
trainIter.reset()
while (!epochDone) {
var doReset = true
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 475d91faa0dc..8b312c621758 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -17,7 +17,7 @@
package org.apache.mxnetexamples.gan
-import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, NDArrayCollector, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, ResourceScope, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
@@ -104,7 +104,7 @@ object GanMnist {
def runTraining(dataPath : String, context : Context,
outputPath : String, numEpoch : Int): Float = {
- val output = NDArrayCollector.auto().withScope {
+ val output = ResourceScope.using() {
val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
@@ -147,7 +147,7 @@ object GanMnist {
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
gMod.update(dataBatch)
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
index c7f2fdac30c3..48e55004cf7b 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
@@ -49,7 +49,7 @@ object ImageClassifierExample {
def runInferenceOnSingleImage(modelPathPrefix: String, inputImagePath: String,
context: Array[Context]):
IndexedSeq[IndexedSeq[(String, Float)]] = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val dType = DType.Float32
val inputShape = Shape(1, 3, 224, 224)
@@ -71,7 +71,7 @@ object ImageClassifierExample {
def runInferenceOnBatchOfImage(modelPathPrefix: String, inputImageDir: String,
context: Array[Context]):
IndexedSeq[IndexedSeq[(String, Float)]] = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val dType = DType.Float32
val inputShape = Shape(1, 3, 224, 224)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
index 07d1cc82e927..8c5366d6279a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
@@ -51,7 +51,7 @@ object SSDClassifierExample {
def runObjectDetectionSingle(modelPathPrefix: String, inputImagePath: String,
context: Array[Context]):
IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val dType = DType.Float32
val inputShape = Shape(1, 3, 512, 512)
// ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
@@ -68,7 +68,7 @@ object SSDClassifierExample {
def runObjectDetectionBatch(modelPathPrefix: String, inputImageDir: String,
context: Array[Context]):
IndexedSeq[IndexedSeq[(String, Array[Float])]] = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val dType = DType.Float32
val inputShape = Shape(1, 3, 512, 512)
// ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 5c17a3747ab6..e406c6d21d23 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -25,7 +25,7 @@ import org.slf4j.LoggerFactory
import scala.collection.JavaConverters._
import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, Executor, NDArray, NDArrayCollector, Shape, Symbol, Xavier}
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, Executor, NDArray, ResourceScope, Shape, Symbol, Xavier}
import org.apache.mxnet.DType.DType
import org.apache.mxnet.optimizer.RMSProp
import org.apache.mxnetexamples.Util
@@ -222,7 +222,7 @@ object ExampleMultiTask {
def train(batchSize: Int, numEpoch: Int, ctx: Context, modelDirPath: String):
(Executor, MultiAccuracy) = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val lr = 0.001f
val network = ExampleMultiTask.buildNetwork()
val (trainIter, valIter) =
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
index 475e179f819b..beb80ced9d4e 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
@@ -170,7 +170,7 @@ object NeuralStyle {
contentWeight : Float, tvWeight : Float, gaussianRadius : Int,
lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
saveEpochs : Int, stopEps: Float) : Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
val size = (contentNp.shape(2), contentNp.shape(3))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
index b1e6634db80e..cd1ed59b6e6d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
@@ -17,7 +17,7 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.apache.mxnet.{Context, NDArrayCollector, Shape}
+import org.apache.mxnet.{Context, ResourceScope, Shape}
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
@@ -29,7 +29,7 @@ object BoostInference {
def runInference(modelPath: String, outputPath: String, guassianRadius : Int,
inputImage : String, ctx : Context): Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val dShape = Shape(1, 3, 480, 640)
val clipNorm = 1.0f * dShape.product
// generator
@@ -47,7 +47,7 @@ object BoostInference {
DataProcessing.preprocessContentImage(s"$inputImage", dShape, ctx)
var data = Array(contentNp)
for (i <- 0 until gens.length) {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
gens(i).forward(data.takeRight(1))
val newImg = gens(i).getOutputs()(0)
data :+= newImg
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
index 8246f44bae2f..1c9adbaf7560 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.neuralstyle.end2end
import java.io.File
-import org.apache.mxnet.{Context, Executor, NDArray, NDArrayCollector, Shape, Symbol}
+import org.apache.mxnet.{Context, Executor, NDArray, ResourceScope, Shape, Symbol}
import org.apache.mxnet.optimizer.SGD
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
@@ -56,7 +56,7 @@ object BoostTrain {
def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
styleImage : String, saveModelPath : String) : Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
// params
val vggParams = NDArray.load2Map(vggModelPath)
val styleWeight = 1.2f
@@ -117,7 +117,7 @@ object BoostTrain {
// train
for (i <- startEpoch until endEpoch) {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
filelist = Random.shuffle(filelist)
for (idx <- filelist.indices) {
var dataArray = Array[NDArray]()
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
index 8b2059d2e119..d1a70a755b01 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/LstmBucketing.scala
@@ -62,7 +62,7 @@ object LstmBucketing {
def runTraining(trainData : String, validationData : String,
ctx : Array[Context], numEpoch : Int): Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val batchSize = 32
val buckets = Array(10, 20, 30, 40, 50, 60)
val numHidden = 200
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
index bf2eba660388..750fd9837e53 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TestCharRnn.scala
@@ -34,7 +34,7 @@ object TestCharRnn {
private val logger = LoggerFactory.getLogger(classOf[TrainCharRnn])
def runInferenceCharRNN(dataPath: String, modelPrefix: String, starterSentence : String): Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
// The batch size for training
val batchSize = 32
// We can support various length input
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index 68346afe1f47..2704715b0c4d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -33,7 +33,7 @@ object TrainCharRnn {
def runTrainCharRnn(dataPath: String, saveModelPath: String,
ctx : Context, numEpoch : Int): Unit = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
// The batch size for training
val batchSize = 32
// We can support various length input
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
index ae0ee33002d9..0424c1262835 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
import java.net.URL
import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
index 709ea77632e0..f6872aedfe69 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.gan
import java.io.File
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.slf4j.LoggerFactory
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
index c93a7d06a452..5b1cbc525890 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
@@ -17,7 +17,7 @@
package org.apache.mxnetexamples.neuralstyle
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.apache.mxnetexamples.neuralstyle.end2end.{BoostInference, BoostTrain}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
index ff3fbe9e05d2..ca62f484ac20 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -18,7 +18,7 @@
package org.apache.mxnetexamples.rnn
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.slf4j.LoggerFactory
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 8f72bc259afc..335856356534 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -207,7 +207,11 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s,
using namespace mxnet::op::mxnet_op;
const TBlob& out_data = out->data();
MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type
- Kernel::Launch(s, out_data.Size(), out_data.dptr());
+ // Do not set_zero when output mem inplace with input[0] mem
+ // Now for add_n OP, output mem can be in-placed with the first input
+ if (nds[0].data().dptr() != out_data.dptr()) {
+ Kernel::Launch(s, out_data.Size(), out_data.dptr());
+ }
for (size_t i = 0; i < nds.size(); ++i) {
const NDArray& nd = nds[i];
const TBlob& nd_data = nd.data();
diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc
index cb57b1b2d16f..971ff6ad560b 100644
--- a/src/operator/nn/upsampling.cc
+++ b/src/operator/nn/upsampling.cc
@@ -121,9 +121,56 @@ struct UpSamplingGrad {
DMLC_REGISTER_PARAMETER(UpSamplingParam);
NNVM_REGISTER_OP(UpSampling)
-.describe("Performs nearest neighbor/bilinear up sampling to inputs. "
- "Bilinear upsampling makes use of deconvolution. Therefore, "
- "provide 2 inputs - data and weight. ")
+.describe(R"code(Upsamples the given input data.
+
+Two algorithms (``sample_type``) are available for upsampling:
+
+- Nearest Neighbor
+- Bilinear
+
+**Nearest Neighbor Upsampling**
+
+Input data is expected to be NCHW.
+
+Example::
+
+ x = [[[[1. 1. 1.]
+ [1. 1. 1.]
+ [1. 1. 1.]]]]
+
+ UpSampling(x, scale=2, sample_type='nearest') = [[[[1. 1. 1. 1. 1. 1.]
+ [1. 1. 1. 1. 1. 1.]
+ [1. 1. 1. 1. 1. 1.]
+ [1. 1. 1. 1. 1. 1.]
+ [1. 1. 1. 1. 1. 1.]
+ [1. 1. 1. 1. 1. 1.]]]]
+
+**Bilinear Upsampling**
+
+Uses `deconvolution` algorithm under the hood. You need provide both input data and the kernel.
+
+Input data is expected to be NCHW.
+
+`num_filter` is expected to be same as the number of channels.
+
+Example::
+
+ x = [[[[1. 1. 1.]
+ [1. 1. 1.]
+ [1. 1. 1.]]]]
+
+ w = [[[[1. 1. 1. 1.]
+ [1. 1. 1. 1.]
+ [1. 1. 1. 1.]
+ [1. 1. 1. 1.]]]]
+
+ UpSampling(x, w, scale=2, sample_type='bilinear', num_filter=1) = [[[[1. 2. 2. 2. 2. 1.]
+ [2. 4. 4. 4. 4. 2.]
+ [2. 4. 4. 4. 4. 2.]
+ [2. 4. 4. 4. 4. 2.]
+ [2. 4. 4. 4. 4. 2.]
+ [1. 2. 2. 2. 2. 1.]]]]
+)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const UpSamplingParam& params = nnvm::get(attrs.parsed);
return params.sample_type == up_enum::kNearest ? params.num_args : 2;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 37f21ce6d126..d164333953f2 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -27,7 +27,7 @@
#define MXNET_OPERATOR_RNN_INL_H_
#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
+#define MXNET_USE_CUDNN_GE_7200 MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
#include
#include
@@ -39,6 +39,7 @@
#include
#include
#include
+
#include "./math.h"
#include "./math_functions-inl.h"
#include "./operator_common.h"
@@ -48,10 +49,10 @@ namespace mxnet {
namespace op {
namespace rnn_enum {
- enum RNNOpInputs {kData, kParams, kState, kStateCell};
+ enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
- enum RNNOpResource {kCuDNNDropoutDescSpace};
+ enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
}
inline int GetRnnParamSize(int num_layer,
@@ -166,6 +167,8 @@ struct RNNParam : public dmlc::Parameter {
int mode;
float p;
int seq_length_, batch_size_, input_size_;
+
+ bool use_sequence_length;
dmlc::optional projection_size;
dmlc::optional lstm_state_clip_min, lstm_state_clip_max;
bool lstm_state_clip_nan;
@@ -212,9 +215,22 @@ struct RNNParam : public dmlc::Parameter {
.set_default(false)
.describe("Whether to stop NaN from propagating in state by clipping it to min/max. "
"If clipping range is not specified, this option is ignored.");
+
+ DMLC_DECLARE_FIELD(use_sequence_length)
+ .set_default(false)
+ .describe(
+ "If set to true, this layer takes in an extra input parameter "
+ "`sequence_length` "
+ "to specify variable length sequence");
}
};
+inline size_t GetNumInputArguments(RNNParam param_) {
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U;
+ if (param_.use_sequence_length) num_inputs += 1U;
+ return num_inputs;
+}
+
/**
* @params: ws: Temp workspace for gemm's output storage.
* rs: Reserve space of forward intermediate data used for training.
@@ -379,7 +395,7 @@ void RNNBackward(DType* ws,
}
}
-template
+template
class RNNOp {
public:
RNNParam param_;
@@ -415,7 +431,7 @@ class RNNOp {
default:
LOG(FATAL) << "Not implmented";
}
-#if USE_CUDNN_LSTM_PROJ
+#if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
<< "Projection is only supported for LSTM.";
@@ -426,7 +442,7 @@ class RNNOp {
CHECK(!param_.projection_size.has_value())
<< "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
#endif
-#if USE_CUDNN_LSTM_PROJ
+#if MXNET_USE_CUDNN_GE_7200
if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
CHECK_EQ(param_.mode, rnn_enum::kLstm)
@@ -459,7 +475,7 @@ class RNNOp {
CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
@@ -515,7 +531,7 @@ class RNNOp {
Storage::Get()->Free(temp_space_);
Storage::Get()->Free(reserve_space_);
}
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
@@ -541,8 +557,9 @@ class RNNOp {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
- << "unsupported dropout value, should be 0 <= dropout < 1";
- size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ << "unsupported dropout value, should be 0 <= dropout < 1";
+ size_t num_inputs = GetNumInputArguments(param_);
+
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
@@ -553,6 +570,7 @@ class RNNOp {
CHECK_EQ(in_data.size(), num_inputs);
CHECK_EQ(out_data.size(), num_outputs);
Stream *s = ctx.get_stream();
+
// get input + output tensors
Tensor x = in_data[rnn_enum::kData].get(s);
Tensor w = in_data[rnn_enum::kParams].get(s);
@@ -562,6 +580,7 @@ class RNNOp {
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
+
const int direction = param_.bidirectional ? 2 : 1;
const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
DType* b_ptr = w.dptr_ + w.shape_[0] - bsize;
@@ -570,65 +589,130 @@ class RNNOp {
if (param_.state_outputs) {
hy_ptr = out_data[rnn_enum::kStateOut].dptr();
}
+
+
+#if MXNET_USE_CUDNN_GE_7200
+ Tensor host_workspace;
+ int *sequence_length_cpu_int = NULL;
+ IType *sequence_length_cpu_itype = NULL;
+
+ if (ctx_.dev_type == kGPU) {
+ int host_workspace_bytes =
+ param_.batch_size_ * sizeof(IType) + param_.batch_size_ * sizeof(int);
+
+ host_workspace =
+ ctx.requested[rnn_enum::kTempSpace].get_host_space_typed<1, char>(
+ Shape1(host_workspace_bytes));
+
+ sequence_length_cpu_int = reinterpret_cast(host_workspace.dptr_);
+ sequence_length_cpu_itype =
+ reinterpret_cast(host_workspace.dptr_ + sizeof(int) * param_.batch_size_);
+
+ (void)sequence_length_cpu_int;
+ (void)sequence_length_cpu_itype;
+ }
+#endif
+
+
+ if (param_.use_sequence_length) {
+#if MXNET_USE_CUDNN_GE_7200
+ if (ctx_.dev_type == kCPU) {
+ LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN at the moment."
+ << " Not supported on CPU";
+ }
+
+ // We can assume we are on GPU for now
+ size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+ if (param_.mode != rnn_enum::kLstm) {
+ seq_len_input_idx -= 1;
+ }
+ IType *sequence_length_ptr_gpu = (in_data[seq_len_input_idx].get(s)).dptr_;
+
+ // Need to copy from GPU -> CPU, becuase cuDNN API requires this array on CPU memory.
+ // TODO(stephenrawls): In future, allow users to pass this array on the CPU so we don't have
+ // to do this copy For now however it is required as several places in backend assume that
+ // all data arrays share the same context.
+ CUDA_CALL(cudaMemcpy(sequence_length_cpu_itype, sequence_length_ptr_gpu,
+ sizeof(IType) * param_.batch_size_, cudaMemcpyDeviceToHost));
+#else
+ LOG(FATAL) << "RNN use_sequence_length option is only available for cuDNN version >= 7.2";
+#endif
+ }
DType* cx_ptr = NULL;
DType* cy_ptr = NULL;
- if (param_.mode == rnn_enum::kLstm)
+ if (param_.mode == rnn_enum::kLstm) {
cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_;
- if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+ }
+ if (param_.mode == rnn_enum::kLstm && param_.state_outputs) {
cy_ptr = (out_data[rnn_enum::kStateCellOut].get(s)).dptr_;
-
+ }
CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
- #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
if (!init_cudnn_) {
Init(ctx, s, in_data, out_data);
}
- #if USE_CUDNN_LSTM_PROJ
- std::vector seqLengthArray(param_.batch_size_, param_.seq_length_);
+#if MXNET_USE_CUDNN_GE_7200
+
+ cudnnRNNDataLayout_t layout_t;
+
+ if (param_.use_sequence_length) {
+ // Note: Can't mempcy, sequence_length_ptr_cpu is of type Itype, not nescesarily int
+ for (int i = 0; i < param_.batch_size_; ++i) {
+ sequence_length_cpu_int[i] = sequence_length_cpu_itype[i];
+ }
+ layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
+ } else {
+ for (int i = 0; i < param_.batch_size_; ++i) {
+ sequence_length_cpu_int[i] = param_.seq_length_;
+ }
+ layout_t = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED;
+ }
+
CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
dtype_,
- CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ layout_t,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
- seqLengthArray.data(),
- nullptr));
+ sequence_length_cpu_int,
+ reinterpret_cast(&padding_fill_)));
int out_size =
(param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
dtype_,
- CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ layout_t,
param_.seq_length_,
param_.batch_size_,
out_size,
- seqLengthArray.data(),
- nullptr));
+ sequence_length_cpu_int,
+ reinterpret_cast(&padding_fill_)));
if (ctx.is_train) {
CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
dtype_,
- CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ layout_t,
param_.seq_length_,
param_.batch_size_,
param_.input_size_,
- seqLengthArray.data(),
- nullptr));
+ sequence_length_cpu_int,
+ reinterpret_cast(&padding_fill_)));
CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
dtype_,
- CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ layout_t,
param_.seq_length_,
param_.batch_size_,
out_size,
- seqLengthArray.data(),
- nullptr));
+ sequence_length_cpu_int,
+ reinterpret_cast(&padding_fill_)));
}
- #endif
+#endif
- #if USE_CUDNN_LSTM_PROJ
+#if MXNET_USE_CUDNN_GE_7200
bool clip_state = param_.lstm_state_clip_min.has_value();
bool clip_nan = param_.lstm_state_clip_nan;
CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
@@ -637,10 +721,10 @@ class RNNOp {
clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
clip_state ? param_.lstm_state_clip_min.value() : 0.0,
clip_state ? param_.lstm_state_clip_max.value() : 0.0));
- #endif
+#endif
if (ctx.is_train) {
- #if USE_CUDNN_LSTM_PROJ
+#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
@@ -669,7 +753,7 @@ class RNNOp {
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
- #else
+#else
CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
@@ -691,9 +775,9 @@ class RNNOp {
workspace_byte_,
reserve_space_.dptr,
reserve_space_byte_));
- #endif
+#endif
} else {
- #if USE_CUDNN_LSTM_PROJ
+#if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
rnn_desc_,
x_data_desc_,
@@ -720,7 +804,7 @@ class RNNOp {
nullptr,
temp_space_.dptr,
workspace_byte_));
- #else
+#else
CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
@@ -740,22 +824,22 @@ class RNNOp {
cy_ptr,
temp_space_.dptr,
workspace_byte_));
- #endif
+#endif
}
- #endif
+#endif
if (ctx_.dev_type == kCPU) {
// allocate temp space
const size_t work_cpu_space_size =
- GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
- param_.state_size, direction, param_.mode);
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
- Storage::Get()->Free(temp_cpu_space_);
- temp_init_space_ = false;
+ Storage::Get()->Free(temp_cpu_space_);
+ temp_init_space_ = false;
}
if (!temp_init_space_) {
temp_cpu_space_ = Storage::Get()->Alloc
- (work_cpu_space_size * sizeof(DType), Context::CPU());
+ (work_cpu_space_size * sizeof(DType), Context::CPU());
temp_cpu_space_size_ = work_cpu_space_size;
temp_init_space_ = true;
}
@@ -828,7 +912,8 @@ class RNNOp {
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ size_t num_inputs = GetNumInputArguments(param_);
+
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
@@ -890,15 +975,16 @@ class RNNOp {
cx_ptr = (in_data[rnn_enum::kStateCell].get(s)).dptr_;
dcx_ptr = (in_grad[rnn_enum::kStateCell].get(s)).dptr_;
}
- if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+ if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs) {
dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get(s)).dptr_;
+ }
#if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
if (!init_cudnn_) {
Init(ctx, s, in_data, out_data);
}
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
rnn_desc_,
y_data_desc_,
@@ -1038,19 +1124,19 @@ class RNNOp {
}
}
-
private:
inline void Init(const OpContext &ctx,
mshadow::Stream *s,
const std::vector &in_data,
const std::vector &out_data) {
using namespace mshadow;
- size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+
+ size_t num_inputs = GetNumInputArguments(param_);
// kOut
size_t num_outputs = 1;
if (param_.state_outputs) {
// kOut, kStateOut, kStateCellOut
- num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3U : 2U;
}
CHECK_EQ(in_data.size(), num_inputs);
@@ -1130,7 +1216,7 @@ class RNNOp {
strideA[0] = dimA[2] * dimA[1];
strideA[1] = dimA[2];
strideA[2] = 1;
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
int dimB[3];
int strideB[3];
dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
@@ -1141,7 +1227,7 @@ class RNNOp {
strideB[1] = dimB[2];
strideB[2] = 1;
#endif
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
dtype_,
3,
@@ -1159,7 +1245,7 @@ class RNNOp {
3,
dimA,
strideA));
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
dtype_,
3,
@@ -1177,7 +1263,7 @@ class RNNOp {
3,
dimA,
strideA));
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
dtype_,
3,
@@ -1195,7 +1281,7 @@ class RNNOp {
3,
dimA,
strideA));
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
dtype_,
3,
@@ -1258,12 +1344,13 @@ class RNNOp {
}
#if CUDNN_VERSION >= 7200
if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
- (DataType::kFlag != kFloat16))
+ (DataType::kFlag != kFloat16)) {
math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
+ }
#endif
CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
#endif
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
if (param_.projection_size.has_value()) {
CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
rnn_desc_,
@@ -1272,6 +1359,13 @@ class RNNOp {
}
#endif
// Get temp space sizes
+
+ #if MXNET_USE_CUDNN_GE_7200
+ if (param_.use_sequence_length) {
+ CUDNN_CALL(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
+ }
+ #endif
+
CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
rnn_desc_,
param_.seq_length_,
@@ -1360,8 +1454,9 @@ class RNNOp {
size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
int workspace_size_;
std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
- #if USE_CUDNN_LSTM_PROJ
+ #if MXNET_USE_CUDNN_GE_7200
cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
+ DType padding_fill_ = 0;
#endif
cudnnTensorDescriptor_t hx_desc_, cx_desc_;
cudnnTensorDescriptor_t hy_desc_, cy_desc_;
@@ -1387,13 +1482,22 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
const std::vector &in_types) {
const RNNParam& param = nnvm::get(attrs.parsed);
OpStatePtr state = OpStatePtr();
- MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, {
+ int dtype = in_types[rnn_enum::kData];
+ int itype = dtype;
+ if (param.use_sequence_length) {
+ itype = in_types[rnn_enum::kSequenceLength];
+ if (param.mode == rnn_enum::kLstm) itype -= 1;
+ }
+
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
if (ctx.dev_type == kGPU) {
- state = OpStatePtr::Create>(param, ctx);
+ state = OpStatePtr::Create>(param, ctx);
} else {
- state = OpStatePtr::Create>(param, ctx);
+ state = OpStatePtr::Create>(param, ctx);
}
});
+ });
return state;
}
@@ -1404,10 +1508,18 @@ void RNNStatefulCompute(const OpStatePtr& state,
const std::vector& req,
const std::vector& outputs) {
int dtype = inputs[rnn_enum::kData].type_flag_;
+
+ // Hacky. This relies on fact that seq-len type is either the last input,
+ // or we aren't using seq-len input and this type should be same as dtype.
+ // Would prefer direct access to RNNParam object here but not sure how to get.
+ int itype = inputs[inputs.size()-1].type_flag_;
+
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- RNNOp& op = state.get_state>();
- op.Forward(ctx, inputs, req, outputs);
- });
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ RNNOp& op = state.get_state>();
+ op.Forward(ctx, inputs, req, outputs);
+ });
+ });
}
/*
@@ -1435,25 +1547,33 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
const std::vector &in_grad = outputs;
int dtype = inputs[rnn_enum::kData].type_flag_;
- MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- RNNOp& op = state.get_state>();
- const RNNParam& param = op.param_;
- int index = 5;
- if (param.state_outputs) {
- out_data.push_back(inputs[index++]);
- out_grad.push_back(inputs[index++]);
- }
- if (param.mode == rnn_enum::kLstm) {
- in_data.push_back(inputs[index++]);
- if (param.state_outputs) {
- out_data.push_back(inputs[index++]);
- out_grad.push_back(inputs[index]);
- }
- }
+ // Hacky. This relies on fact that seq-len type is either the last input,
+ // or we aren't using seq-len input and this type should be same as dtype.
+ // Would prefer direct access to RNNParam object here but not sure how to get.
+ int itype = inputs[inputs.size()-1].type_flag_;
- op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
- });
+ MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+ MSHADOW_TYPE_SWITCH(itype, IType, {
+ RNNOp& op = state.get_state>();
+ const RNNParam& param = op.param_;
+ int index = 5;
+ if (param.state_outputs) {
+ out_data.push_back(inputs[index++]);
+ out_grad.push_back(inputs[index++]);
+ }
+
+ if (param.mode == rnn_enum::kLstm) {
+ in_data.push_back(inputs[index++]);
+ if (param.state_outputs) {
+ out_data.push_back(inputs[index++]);
+ out_grad.push_back(inputs[index]);
+ }
+ }
+
+ op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
+ });
+ });
}
} // namespace op
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 7012a3c22f50..296d57eb4713 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -23,6 +23,9 @@
* \brief
* \author Sebastian Bodenstein
*/
+
+#include
+
#include "./rnn-inl.h"
namespace mxnet {
@@ -30,11 +33,20 @@ namespace op {
DMLC_REGISTER_PARAMETER(RNNParam);
static inline std::vector ListArguments(const RNNParam& param_) {
+ // All RNNs start off with same 3 input arguments
+ std::vector arguments{"data", "parameters", "state"};
+
+ // LSTMs also have an additional state_cell argument
if (param_.mode == rnn_enum::kLstm) {
- return {"data", "parameters", "state", "state_cell"};
- } else {
- return {"data", "parameters", "state"};
+ arguments.emplace_back("state_cell");
}
+
+ // All RNNs have option of additional sequence_length argument
+ if (param_.use_sequence_length) {
+ arguments.emplace_back("sequence_length");
+ }
+
+ return arguments;
}
static bool RNNShape(const nnvm::NodeAttrs& attrs,
@@ -42,13 +54,13 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs,
std::vector *out_shape) {
const RNNParam& param_ = nnvm::get(attrs.parsed);
using namespace mshadow;
- if (param_.mode == rnn_enum::kLstm) {
- CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, cell_state],"
- << " got in_shape->size(): " << in_shape->size();
- } else {
- CHECK_EQ(in_shape->size(), 3U) <<
- "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size();
- }
+
+ // Query param_ object to figure out what the expectd input arguments are
+ std::vector expected_arguments = ListArguments(param_);
+
+ CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " <<
+ expected_arguments.size() << " input parameters but got " << in_shape->size() << ".";
+
const TShape &dshape = (*in_shape)[rnn_enum::kData];
if (!mxnet::ndim_is_known(dshape)) return false;
CHECK_EQ(dshape.ndim(), 3U) \
@@ -77,6 +89,15 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs,
param_.mode,
param_.projection_size);
SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
+
+ // Check on sequence_length shape if using
+ if (param_.use_sequence_length) {
+ size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+ if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx;
+
+ SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size));
+ }
+
out_shape->clear();
// output: [sequence len, batch, output size]
TShape oshape = dshape;
@@ -106,6 +127,7 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs,
out_shape->push_back(cellStateShape);
}
}
+
return true;
}
@@ -113,18 +135,24 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
std::vector *in_type,
std::vector *out_type) {
const RNNParam& param_ = nnvm::get(attrs.parsed);
- if (param_.mode == rnn_enum::kLstm) {
- CHECK_EQ(in_type->size(), 4U);
- } else {
- CHECK_EQ(in_type->size(), 3U);
- }
+
+ CHECK_EQ(in_type->size(), GetNumInputArguments(param_));
+
+ size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+ if (param_.mode != rnn_enum::kLstm) --seq_len_input_idx;
+
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
+ std::vector arguments = ListArguments(param_);
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
TYPE_ASSIGN_CHECK(*in_type, i, dtype);
} else {
- UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
+ // If using sequence length argument, it has its own indexing type
+ // All other input arguments must match the main data type
+ if (!(param_.use_sequence_length && i == seq_len_input_idx)) {
+ UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]);
+ }
}
}
out_type->clear();
@@ -132,8 +160,9 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
if (param_.state_outputs) {
out_type->push_back(dtype);
// Deal with lstm cell state
- if (param_.mode == rnn_enum::kLstm)
+ if (param_.mode == rnn_enum::kLstm) {
out_type->push_back(dtype);
+ }
}
return true;
}
@@ -220,7 +249,7 @@ The definition of GRU here is slightly different from paper but compatible with
.set_attr_parser(ParamParser)
.set_num_inputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get(attrs.parsed);
- return params.mode == rnn_enum::kLstm ? 4 : 3;
+ return GetNumInputArguments(params);
})
.set_num_outputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get(attrs.parsed);
@@ -246,13 +275,13 @@ The definition of GRU here is slightly different from paper but compatible with
.set_attr("FResourceRequestEx",
[](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
std::vector request;
- const RNNParam& param = nnvm::get(attrs.parsed);
- if (param.p == 0) return request;
if (dev_mask == kGPU) {
#if MXNET_USE_CUDNN_RNN
- if (1.0f - param.p > 0) {
+ request.emplace_back(ResourceRequest::kTempSpace);
+
+ const RNNParam& param = nnvm::get(attrs.parsed);
+ if (param.p != 0 && 1.0f - param.p > 0) {
request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
- return request;
}
#endif
}
@@ -264,12 +293,15 @@ The definition of GRU here is slightly different from paper but compatible with
.add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
.add_argument("state_cell", "NDArray-or-Symbol",
"initial cell state for LSTM networks (only for LSTM)")
+.add_argument("sequence_length", "NDArray-or-Symbol",
+ "Vector of valid sequence lengths for each element in batch. (Only used if"
+ " use_sequence_length kwarg is True)")
.add_arguments(RNNParam::__FIELDS__());
NNVM_REGISTER_OP(_backward_RNN)
.set_num_outputs([](const NodeAttrs& attrs) {
const RNNParam& params = nnvm::get(attrs.parsed);
- return params.mode == rnn_enum::kLstm ? 4 : 3;
+ return GetNumInputArguments(params);
})
.set_attr_parser(ParamParser)
.set_attr("TIsLayerOpBackward", true)
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 77bb95522711..093a64a0623a 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -30,6 +30,7 @@
namespace mxnet {
namespace op {
+
NNVM_REGISTER_OP(RNN)
.set_attr("FStatefulCompute", RNNStatefulCompute);
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b80c9a54510f..9e6bead7229c 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -769,7 +769,7 @@ parameter values:
if (!dispatched && param.a_min <= 0.0 && param.a_max >= 0.0) {
const int this_stype = (*in_attrs)[0];
if (this_stype != kUndefinedStorage) {
- dispatched = storage_type_assign(&(*out_attrs)[0], kRowSparseStorage,
+ dispatched = storage_type_assign(&(*out_attrs)[0], mxnet::NDArrayStorageType(this_stype),
dispatch_mode, DispatchMode::kFComputeEx);
}
}
diff --git a/src/profiler/nvtx.cc b/src/profiler/nvtx.cc
new file mode 100644
index 000000000000..9151873aa82a
--- /dev/null
+++ b/src/profiler/nvtx.cc
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+#include "nvtx.h"
diff --git a/src/profiler/nvtx.h b/src/profiler/nvtx.h
new file mode 100644
index 000000000000..c36bb50d4703
--- /dev/null
+++ b/src/profiler/nvtx.h
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+#ifndef MXNET_PROFILER_NVTX_H_
+#define MXNET_PROFILER_NVTX_H_
+
+#if MXNET_USE_NVTX
+
+#include
+#include
+#include "nvToolsExt.h"
+
+namespace mxnet {
+namespace profiler {
+namespace nvtx {
+
+class NVTXDuration {
+ public:
+ explicit NVTXDuration(const char *name) noexcept
+ : range_id_(0), name_(name) {}
+
+ inline void start() {
+ range_id_ = nvtxRangeStartA(name_);
+ }
+
+ inline void stop() {
+ nvtxRangeEnd(range_id_);
+ }
+
+ private:
+ nvtxRangeId_t range_id_;
+ const char *name_;
+};
+
+
+
+} // namespace nvtx
+} // namespace profiler
+} // namespace mxnet
+
+#endif // MXNET_USE_NVTX
+#endif // MXNET_PROFILER_NVTX_H_
diff --git a/src/profiler/profiler.h b/src/profiler/profiler.h
index f1fac9ae8ddd..f9eb0af9acc1 100644
--- a/src/profiler/profiler.h
+++ b/src/profiler/profiler.h
@@ -35,6 +35,7 @@
#include
#include "./vtune.h"
#include "./aggregate_stats.h"
+#include "./nvtx.h"
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
#include
@@ -489,6 +490,12 @@ class Profiler {
#define VTUNE_ONLY_CODE(...) /* */ /* This is undefined at the bottom of this file */
#endif
+#ifdef MXNET_USE_NVTX
+#define NVTX_ONLY_CODE(...) __VA_ARGS__ /* This is undefined at the bottom of this file */
+#else
+#define NVTX_ONLY_CODE(...) /* */ /* This is undefined at the bottom of this file */
+#endif
+
/**
* _____ __ _ _ _ ____ _ _ _
* | __ \ / _|(_)| |(_) / __ \| | (_) | |
@@ -777,6 +784,7 @@ struct ProfileTask : public ProfileDuration {
categories_.set(domain_->name());
categories_.append(",task");
VTUNE_ONLY_CODE(vtune_task_.reset(new vtune::VTuneTask(name, domain->dom())));
+ NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
}
/*!
@@ -785,6 +793,7 @@ struct ProfileTask : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_task_->start());
+ NVTX_ONLY_CODE(nvtx_duration_->start());
}
/*!
@@ -792,6 +801,7 @@ struct ProfileTask : public ProfileDuration {
*/
void stop() override {
VTUNE_ONLY_CODE(vtune_task_->stop());
+ NVTX_ONLY_CODE(nvtx_duration_->stop());
SendStat();
}
@@ -831,6 +841,8 @@ struct ProfileTask : public ProfileDuration {
ProfileDomain *domain_;
/*! \brief VTune task object */
VTUNE_ONLY_CODE(std::unique_ptr vtune_task_);
+ /*! \brief NVTX duration object */
+ NVTX_ONLY_CODE(std::unique_ptr nvtx_duration_);
protected:
/*! \brief Task's start tick */
@@ -849,6 +861,7 @@ struct ProfileEvent : public ProfileDuration {
: name_(name)
, categories_("event") {
VTUNE_ONLY_CODE(vtune_event_ = vtune::VTuneEvent::registry_.get(name));
+ NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
}
/*!
@@ -857,6 +870,7 @@ struct ProfileEvent : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_event_->start());
+ NVTX_ONLY_CODE(nvtx_duration_->start());
}
/*!
@@ -905,6 +919,8 @@ struct ProfileEvent : public ProfileDuration {
profile_stat_string categories_;
/*! \brief VTune event object */
VTUNE_ONLY_CODE(vtune::VTuneEvent *vtune_event_);
+ /*! \brief NVTX duration object */
+ NVTX_ONLY_CODE(std::unique_ptr nvtx_duration_;);
protected:
/*! \brief Start time of the event */
@@ -926,6 +942,7 @@ struct ProfileFrame : public ProfileDuration {
CHECK_NOTNULL(domain);
categories_.set(domain_->name());
categories_.append(",frame");
+ NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
VTUNE_ONLY_CODE(vtune_frame_.reset(new vtune::VTuneFrame(domain->dom())));
}
@@ -935,6 +952,7 @@ struct ProfileFrame : public ProfileDuration {
void start() override {
start_time_ = ProfileStat::NowInMicrosec();
VTUNE_ONLY_CODE(vtune_frame_->start());
+ NVTX_ONLY_CODE(nvtx_duration_->start());
}
/*!
@@ -977,6 +995,8 @@ struct ProfileFrame : public ProfileDuration {
ProfileDomain *domain_;
/*! \brief VTune Frame object */
VTUNE_ONLY_CODE(std::unique_ptr vtune_frame_);
+ /*! \brief NVTX duration object */
+ NVTX_ONLY_CODE(std::unique_ptr nvtx_duration_);
protected:
/*! \brief Frame start time */
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 1b7dad487a68..f798cbc1034e 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -17,6 +17,7 @@
import mxnet as mx
import numpy as np
+from mxnet.test_utils import rand_ndarray, assert_almost_equal
from mxnet import gluon, nd
from tests.python.unittest.common import with_seed
@@ -185,7 +186,36 @@ def test_pick():
b = mx.nd.ones(shape=(256*35,))
res = mx.nd.pick(a,b)
assert res.shape == b.shape
-
+
+def test_depthtospace():
+ def numpy_depth_to_space(x, blocksize):
+ b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
+ tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+ tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
+ y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+ return y
+
+ shape_inp = (LARGE_X, 8, 4, 2)
+ data = rand_ndarray(shape_inp, 'default')
+ data_np = data.asnumpy()
+ expected = numpy_depth_to_space(data_np, 2)
+ output = mx.nd.depth_to_space(data, 2)
+ assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
+
+def test_spacetodepth():
+ def numpy_space_to_depth(x, blocksize):
+ b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
+ tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+ tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
+ y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize])
+ return y
+
+ shape_inp = (LARGE_X, 2, 8, 4)
+ data = rand_ndarray(shape_inp, 'default')
+ data_np = data.asnumpy()
+ expected = numpy_space_to_depth(data_np, 2)
+ output = mx.nd.space_to_depth(data, 2)
+ assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
if __name__ == '__main__':
import nose
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 1c5a5835e6f9..95835fd77e9e 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -24,6 +24,7 @@
import unittest
import random
import mxnet as mx
+import mxnet.ndarray as nd
import numpy as np
import unittest
import math
@@ -225,6 +226,55 @@ def forward(self, inpt):
assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())
+def check_layer_bidirectional_varseqlen(size, in_size):
+ class RefBiLSTMVarSeqLen(gluon.Block):
+ def __init__(self, size, **kwargs):
+ super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
+ with self.name_scope():
+ self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
+ self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
+
+ def forward(self, inpt, sequence_length):
+ fwd = self._lstm_fwd(inpt)
+ bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
+ bwd = self._lstm_bwd(bwd_inpt)
+ bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
+ return nd.concat(fwd, bwd, dim=2)
+ weights = {}
+ for d in ['l', 'r']:
+ weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
+ weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size))
+ weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
+ weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
+
+ net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
+ ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
+ net.initialize()
+ ref_net.initialize()
+ net_params = net.collect_params()
+ ref_net_params = ref_net.collect_params()
+ for k in weights:
+ net_params[k].set_data(weights[k])
+ ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
+
+
+ batch_size = 10
+ num_timesteps = 11
+ data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+
+ # TODO: figure out why int32 doesn't work here
+ sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
+
+ net_output = net(data, sequence_length=sequence_length).asnumpy()
+ ref_net_output = ref_net(data, sequence_length).asnumpy()
+ sequence_length_np = sequence_length.asnumpy().astype("int32")
+
+ # TODO: test state return value as well output
+ # Only compare the valid sections for each batch entry
+ for b in range(batch_size):
+ assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+
+
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_layer_bidirectional():
@@ -236,6 +286,11 @@ def test_layer_bidirectional():
def test_layer_bidirectional_proj():
check_layer_bidirectional(7, 5, 3)
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_layer_bidirectional_varseqlength():
+ check_layer_bidirectional_varseqlen(7, 5)
+
@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
diff --git a/tests/python/profiling/simple_forward.py b/tests/python/profiling/simple_forward.py
new file mode 100644
index 000000000000..0ad43c89a6f5
--- /dev/null
+++ b/tests/python/profiling/simple_forward.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.gluon import nn
+
+
+def simple_forward():
+ ctx = mx.gpu()
+ mx.profiler.set_config(profile_all=True)
+ mx.profiler.set_state('run')
+
+ # define simple gluon network with random weights
+ net = nn.Sequential()
+ with net.name_scope():
+ net.add(nn.Dense(128, activation='relu'))
+ net.add(nn.Dense(64, activation='relu'))
+ net.add(nn.Dense(10))
+ net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
+
+ input = mx.nd.zeros((128,), ctx=ctx)
+ predictions = net(input)
+ print('Ran simple NN forward, results:')
+ print(predictions.asnumpy())
+
+
+if __name__ == '__main__':
+ simple_forward()
diff --git a/tests/python/profiling/test_nvtx.py b/tests/python/profiling/test_nvtx.py
new file mode 100644
index 000000000000..35b209ebb6eb
--- /dev/null
+++ b/tests/python/profiling/test_nvtx.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import unittest
+
+import mxnet as mx
+import sys
+
+from subprocess import Popen, PIPE
+
+
+def test_nvtx_ranges_present_in_profile():
+
+ if not mx.test_utils.list_gpus():
+ unittest.skip('Test only applicable to machines with GPUs')
+
+ # Build a system independent wrapper to execute simple_forward with nvprof
+ # This requires nvprof to be on your path (which should be the case for most GPU workstations with cuda installed).
+ simple_forward_path = os.path.realpath(__file__)
+ simple_forward_path = simple_forward_path.replace('test_nvtx', 'simple_forward')
+
+ process = Popen(["nvprof", sys.executable, simple_forward_path], stdout=PIPE, stderr=PIPE)
+ (output, profiler_output) = process.communicate()
+ process.wait()
+ profiler_output = profiler_output.decode('ascii')
+
+ # Verify that some of the NVTX ranges we should have created are present
+ # Verify that we have NVTX ranges for our simple operators.
+ assert "Range \"FullyConnected\"" in profiler_output
+ assert "Range \"_zeros\"" in profiler_output
+
+ # Verify that we have some expected output from the engine.
+ assert "Range \"WaitForVar\"" in profiler_output
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d7761c76aa29..7db07596d7f8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -8357,6 +8357,18 @@ def check_concat(shape1, shape2, axis):
check_concat((8, 0, 0), (8, 0, 0), 2)
+@with_seed()
+def test_add_n():
+ data_shape = (2, 2)
+ input_num = 5
+ data = [mx.nd.random.uniform(shape=data_shape) for i in range(input_num)]
+ rslt = mx.nd.zeros(shape=data_shape)
+ for i in range(input_num):
+ rslt += data[i]
+ add_n_rslt = mx.nd.add_n(*data, out=data[0])
+ assert_almost_equal(rslt.asnumpy(), add_n_rslt.asnumpy(), atol=1e-5)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 7600ea944e83..3b4c684e8696 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -915,6 +915,7 @@ def check_fluent_regular(stype, func, kwargs, shape=(5, 17), equal_nan=False):
check_fluent_regular('csr', 'slice', {'begin': (2, 5), 'end': (4, 7)}, shape=(5, 17))
check_fluent_regular('row_sparse', 'clip', {'a_min': -0.25, 'a_max': 0.75})
+ check_fluent_regular('csr', 'clip', {'a_min': -0.25, 'a_max': 0.75})
for func in ['sum', 'mean', 'norm']:
check_fluent_regular('csr', func, {'axis': 0})
diff --git a/tools/dependencies/libpng.sh b/tools/dependencies/libpng.sh
index f71d4762ab34..39fa24c87ecd 100755
--- a/tools/dependencies/libpng.sh
+++ b/tools/dependencies/libpng.sh
@@ -19,7 +19,7 @@
# This script builds the static library of libpng that can be used as dependency of mxnet/opencv.
set -ex
-PNG_VERSION=1.6.34
+PNG_VERSION=1.6.35
if [[ ! -f $DEPS_PATH/lib/libpng.a ]]; then
# download and build libpng
>&2 echo "Building libpng..."
diff --git a/tools/dependencies/openssl.sh b/tools/dependencies/openssl.sh
index 8e2372c9075a..78673a3ac84b 100755
--- a/tools/dependencies/openssl.sh
+++ b/tools/dependencies/openssl.sh
@@ -19,7 +19,7 @@
# This script builds the static library of openssl that can be used as dependency of mxnet.
set -ex
-OPENSSL_VERSION=1.0.2l
+OPENSSL_VERSION=1.1.1b
if [[ ! -f $DEPS_PATH/lib/libssl.a ]] || [[ ! -f $DEPS_PATH/lib/libcrypto.a ]]; then
# download and build openssl
>&2 echo "Building openssl..."
diff --git a/tools/pip/setup.py b/tools/pip/setup.py
index 71e2549a3f19..fd9ce41c2a80 100644
--- a/tools/pip/setup.py
+++ b/tools/pip/setup.py
@@ -150,7 +150,7 @@ def has_ext_modules(self):
package_data['mxnet'].append('mxnet/libmklml_intel.so')
package_data['mxnet'].append('mxnet/libiomp5.so')
package_data['mxnet'].append('mxnet/libmkldnn.so.0')
- shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/mkldnn/include'),
+ shutil.copytree(os.path.join(CURRENT_DIR, 'mxnet-build/3rdparty/mkldnn/build/install/include'),
os.path.join(CURRENT_DIR, 'mxnet/include/mkldnn'))
if platform.system() == 'Linux':
shutil.copy(os.path.join(os.path.dirname(LIB_PATH[0]), 'libgfortran.so.3'), os.path.join(CURRENT_DIR, 'mxnet'))
diff --git a/tools/staticbuild/build_lib.sh b/tools/staticbuild/build_lib.sh
index 472bf57b101e..927c15d1dabc 100755
--- a/tools/staticbuild/build_lib.sh
+++ b/tools/staticbuild/build_lib.sh
@@ -44,6 +44,9 @@ if [[ $VARIANT == *mkl ]]; then
MKLDNN_LIBFILE='libmkldnn.0.dylib'
fi
$MAKE DEPS_PATH=$DEPS_PATH mkldnn
+ if [ ! -d lib ]; then
+ mkdir lib
+ fi
cp 3rdparty/mkldnn/build/install/lib/$IOMP_LIBFILE lib
cp 3rdparty/mkldnn/build/install/lib/$MKLML_LIBFILE lib
cp 3rdparty/mkldnn/build/install/lib/$MKLDNN_LIBFILE lib