diff --git a/.gitignore b/.gitignore index e44aa6e21464..042b96841ce5 100644 --- a/.gitignore +++ b/.gitignore @@ -216,3 +216,14 @@ patched.txt # Python type checking .mypy_cache/ .pyre/ + +# pipenv file +Pipfile +Pipfile.lock + +# conda package artifacts +conda/Dockerfile.cuda* +conda/pkg + +.envrc +*.nix \ No newline at end of file diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index a768f2f06279..32057b53eee8 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit a768f2f0627917659a4d7167eee3190469b9d164 +Subproject commit 32057b53eee870d73c6c21dc820d6546b4d9a13f diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc new file mode 100644 index 000000000000..1f25be17f72d --- /dev/null +++ b/3rdparty/bfloat16/bfloat16.cc @@ -0,0 +1,80 @@ +/* + Copyright (c) 2019 by Contributors + \file tvm/src/codegen/custom_datatypes/mybfloat16.cc + \brief Small bfloat16 library for use in unittests + + Code originally from TensorFlow; taken and simplified. Original license: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ + +#include <cmath> +#include <cstddef> +#include <cstdint> + +void FloatToBFloat16(const float* src, uint16_t* dst, size_t size) { + const uint16_t* p = reinterpret_cast<const uint16_t*>(src); + uint16_t* q = reinterpret_cast<uint16_t*>(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p += 2, q++, size--) { + *q = p[0]; + } +#else + for (; size != 0; p += 2, q++, size--) { + *q = p[1]; + } +#endif +} + +void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { + const uint16_t* p = reinterpret_cast<const uint16_t*>(src); + uint16_t* q = reinterpret_cast<uint16_t*>(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p++, q += 2, size--) { + q[0] = *p; + q[1] = 0; + } +#else + for (; size != 0; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +#endif +} + +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, + size_t size) { + float a_f, b_f; + BFloat16ToFloat(a, &a_f, 1); + BFloat16ToFloat(b, &b_f, 1); + float out_f = a_f + b_f; + FloatToBFloat16(&out_f, dst, 1); +} + +extern "C" { +TVM_DLL uint16_t FloatToBFloat16_wrapper(float in) { + uint16_t out; + FloatToBFloat16(&in, &out, 1); + return out; +} + +TVM_DLL float BFloat16ToFloat_wrapper(uint16_t in) { + float out; + BFloat16ToFloat(&in, &out, 1); + return out; +} + +TVM_DLL uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b) { + uint16_t out; + BFloat16Add(&a, &b, &out, 1); + return out; +} +} diff --git a/3rdparty/dlpack b/3rdparty/dlpack index 5c792cef3aee..0acb731e0e43 160000 --- a/3rdparty/dlpack +++ b/3rdparty/dlpack @@ -1 +1 @@ -Subproject commit 5c792cef3aee54ad8b7000111c9dc1797f327b59 +Subproject commit 0acb731e0e43d15deee27b66f10e4c5b4e667913 diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 82bf4c2e2af3..3943914eed66 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 82bf4c2e2af312b3d52513aa727483803a2f8734 +Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f diff --git 
a/CMakeLists.txt b/CMakeLists.txt index dceb9f46568e..80b121477631 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,13 @@ tvm_option(USE_SGX "Build with SGX" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) +tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) + +# 3rdparty libraries +tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include") +tvm_option(DMLC_PATH "Path to DMLC" "3rdparty/dmlc-core/include") +tvm_option(RANG_PATH "Path to RANG" "3rdparty/rang/include") +tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") # Contrib library options tvm_option(USE_BLAS "The blas library to be linked" none) @@ -52,11 +59,12 @@ tvm_option(USE_TENSORRT "Build with TensorRT, must have CUDA and CUDNN enabled" tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) # include directories +include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") -include_directories("3rdparty/dlpack/include") -include_directories("3rdparty/dmlc-core/include") -include_directories("3rdparty/rang/include") -include_directories("3rdparty/compiler-rt") +include_directories(${DLPACK_PATH}) +include_directories(${DMLC_PATH}) +include_directories(${RANG_PATH}) +include_directories(${COMPILER_RT_PATH}) # initial variables set(TVM_LINKER_LIBS "") @@ -90,8 +98,13 @@ else(MSVC) set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS} -rdynamic") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS} -rdynamic") else() - set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden ${CMAKE_C_FLAGS}") - set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11 ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") + if (HIDE_PRIVATE_SYMBOLS) + message("Hide private symbols...") + set(CMAKE_C_FLAGS 
"-fvisibility=hidden ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "-fvisibility=hidden ${CMAKE_CXX_FLAGS}") + endif(HIDE_PRIVATE_SYMBOLS) endif () if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) @@ -123,6 +136,8 @@ file(GLOB_RECURSE RELAY_SRCS ) list(APPEND COMPILER_SRCS ${RELAY_SRCS}) +file(GLOB DATATYPE_SRCS src/codegen/datatype/*.cc) +list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) @@ -152,6 +167,8 @@ if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) endif() +list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc) + if(USE_RPC) message(STATUS "Build with RPC support...") file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc) @@ -209,6 +226,7 @@ add_library(tvm_runtime_static STATIC ${RUNTIME_SRCS}) if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG") +else() set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG") endif(USE_RELAY_DEBUG) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4d4515e09410..2f4e0b65190b 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -55,7 +55,8 @@ We do encourage everyone to work anything they are interested in. - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Leyuan Wang](https://github.com/Laurawly): @Laurawly: - topi - [Yao Wang](https://github.com/kevinthesun): @kevinthesun: - topi, vision -- [Eddie Yan](https://github.com/eqy): @eqy - runtime, autotvm, rpc, topi +- [Jian Weng](https://github.com/were): @were: - hybrid script +- [Eddie Yan](https://github.com/eqy) (PMC): @eqy - runtime, autotvm, rpc, topi - [Lianmin Zheng](https://github.com/merrymercy) (PMC): @merrymercy - autotvm, topi, relay ## Reviewers @@ -82,6 +83,7 @@ We do encourage everyone to work anything they are interested in. 
- [Kazutaka Morita](https://github.com/kazum): @kazum - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 +- [Josh Pollock](https://github.com/joshpoll): @joshpoll - [Jared Roesch](https://github.com/jroesch): @jroesch - [Siva](https://github.com/srkreddy1238): @srkreddy1238 - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel @@ -94,6 +96,7 @@ We do encourage everyone to work anything they are interested in. - [Leyuan Wang](https://github.com/Laurawly): @Laurawly - [Jian Weng](https://github.com/were): @were - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene +- [Bing Xu](https://github.com/antinucleon): @antinucleon - [Eddie Yan](https://github.com/eqy): @eqy - [Joshua Z. Zhang](https://github.com/zhreshold): @zhreshold - [Lianmin Zheng](https://github.com/merrymercy): @merrymercy diff --git a/Jenkinsfile b/Jenkinsfile index af5a2ce3eb42..3b6fc032a389 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,7 +39,7 @@ // - Periodically cleanup the old versions on local workers // ci_lint = "tvmai/ci-lint:v0.51" -ci_gpu = "tvmai/ci-gpu:v0.51" +ci_gpu = "tvmai/ci-gpu:v0.52" ci_cpu = "tvmai/ci-cpu:v0.50" ci_i386 = "tvmai/ci-i386:v0.50" @@ -135,6 +135,7 @@ stage('Build') { echo set\\(USE_CUDNN ON\\) >> config.cmake echo set\\(USE_CUDA ON\\) >> config.cmake echo set\\(USE_OPENGL ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-6.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake @@ -157,6 +158,7 @@ stage('Build') { echo set\\(USE_OPENCL ON\\) >> config.cmake echo set\\(USE_ROCM ON\\) >> config.cmake echo set\\(USE_VULKAN ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER clang-6.0\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> 
config.cmake @@ -174,12 +176,14 @@ stage('Build') { cd build cp ../cmake/config.cmake . echo set\\(USE_SORT ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake + echo set\\(HIDE_PRIVATE_SYMBOLS ON\\) >> config.cmake """ make(ci_cpu, 'build', '-j4') pack_lib('cpu', tvm_lib) @@ -202,6 +206,7 @@ stage('Build') { cd build cp ../cmake/config.cmake . echo set\\(USE_SORT ON\\) >> config.cmake + echo set\\(USE_MICRO ON\\) >> config.cmake echo set\\(USE_RPC ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(USE_LLVM llvm-config-5.0\\) >> config.cmake @@ -277,6 +282,17 @@ stage('Integration Test') { } } }, + 'legacy: GPU': { + node('GPU') { + ws('workspace/tvm/legacy-python-gpu') { + init_git() + unpack_lib('gpu', tvm_multilib) + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_legacy.sh" + } + } + } + }, 'docs: GPU': { node('GPU') { ws('workspace/tvm/docs-python-gpu') { diff --git a/apps/android_rpc/README.md b/apps/android_rpc/README.md index 36ace85405c0..38725917f424 100644 --- a/apps/android_rpc/README.md +++ b/apps/android_rpc/README.md @@ -141,7 +141,7 @@ export TVM_NDK_CC=/opt/android-toolchain-arm64/bin/aarch64-linux-android-g++ python android_rpc_test.py ``` -This will compile TVM IR to shared libraries (CPU, OpenCL and Vulkan) and run vector addition on your Android device. 
To verify compiled TVM IR shared libraries on OpenCL target set [`'test_opencl = True'`](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py#L25) and on Vulkan target set [`'test_vulkan = False'`](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py#L27) in [tests/android_rpc_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py), by default on CPU target will execute. +This will compile TVM IR to shared libraries (CPU, OpenCL and Vulkan) and run vector addition on your Android device. To verify compiled TVM IR shared libraries on OpenCL target set `'test_opencl = True'` and on Vulkan target set `'test_vulkan = True'` in [tests/android_rpc_test.py](https://github.com/dmlc/tvm/blob/master/apps/android_rpc/tests/android_rpc_test.py), by default on CPU target will execute. On my test device, it gives following results. ```bash diff --git a/cmake/config.cmake b/cmake/config.cmake index e7ddb9aba6b8..679de8d7e752 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -135,9 +135,6 @@ set(USE_TENSORRT OFF) # Build ANTLR parser for Relay text format set(USE_ANTLR OFF) -# Build TSIM for VTA -set(USE_VTA_TSIM OFF) - # Whether use Relay debug mode set(USE_RELAY_DEBUG OFF) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 1adb0aaf387a..6d5ea000edc2 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -29,8 +29,7 @@ elseif(PYTHON) --use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json) endif() - execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target) - string(STRIP ${__vta_target} VTA_TARGET) + execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE) message(STATUS "Build VTA runtime with target: " ${VTA_TARGET}) @@ -44,6 +43,13 @@ elseif(PYTHON) add_library(vta SHARED ${VTA_RUNTIME_SRCS}) + if(${VTA_TARGET} STREQUAL "tsim") + target_compile_definitions(vta PUBLIC 
USE_TSIM) + include_directories("vta/include") + file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) + list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) + endif() + target_include_directories(vta PUBLIC vta/include) foreach(__def ${VTA_DEFINITIONS}) @@ -55,18 +61,12 @@ elseif(PYTHON) set_target_properties(vta PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") endif(APPLE) - # PYNQ rules for Pynq v2.3 + # PYNQ rules for Pynq v2.4 if(${VTA_TARGET} STREQUAL "pynq") find_library(__cma_lib NAMES cma PATH /usr/lib) target_link_libraries(vta ${__cma_lib}) endif() - if(NOT USE_VTA_TSIM STREQUAL "OFF") - include_directories("vta/include") - file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc) - list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS}) - endif() - else() message(STATUS "Cannot found python in env, VTA build is skipped..") endif() diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake index e1e151d6a9f8..a47f83771d37 100644 --- a/cmake/modules/contrib/BLAS.cmake +++ b/cmake/modules/contrib/BLAS.cmake @@ -27,7 +27,11 @@ elseif(USE_BLAS STREQUAL "mkl") if(NOT IS_DIRECTORY ${USE_MKL_PATH}) set(USE_MKL_PATH /opt/intel/mkl) endif() - find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + if(APPLE) + find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + elseif(UNIX) + find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64) + endif() include_directories(${USE_MKL_PATH}/include) list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY}) list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC}) diff --git a/conda/Dockerfile.cuda100 b/conda/Dockerfile.template similarity index 72% rename from conda/Dockerfile.cuda100 rename to conda/Dockerfile.template index def8c9ac5d6a..59b9ac96814e 100644 --- a/conda/Dockerfile.cuda100 +++ b/conda/Dockerfile.template @@ -15,7 +15,13 @@ # specific language governing permissions and limitations # under the 
License. -FROM nvidia/cuda:10.0-devel-centos6 +FROM nvidia/cuda:{{ cuda_version }}-devel-centos6 + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ + tar --no-same-owner -xzf cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -C /usr/local && \ + rm cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz && \ + ldconfig + RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ chmod +x ~/miniconda.sh && \ @@ -30,4 +36,4 @@ ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 WORKDIR /workspace RUN chmod -R a+w /workspace -CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: 10.0}' /workspace/conda/tvm-libs +CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: {{ cuda_version }}}' /workspace/conda/tvm-libs diff --git a/conda/build_cuda.sh b/conda/Makefile old mode 100755 new mode 100644 similarity index 66% rename from conda/build_cuda.sh rename to conda/Makefile index 2f3207e22987..cda546ac73ce --- a/conda/build_cuda.sh +++ b/conda/Makefile @@ -5,22 +5,18 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-#/bin/sh -condadir=`dirname $0` -condadir=`readlink -f $condadir` -srcdir=`dirname $condadir` -docker build -t tvm-cuda100-forge $condadir -f $condadir/Dockerfile.cuda100 -docker run --rm -v $srcdir:/workspace tvm-cuda100-forge -docker build -t tvm-cuda92-forge $condadir -f $condadir/Dockerfile.cuda92 -docker run --rm -v $srcdir:/workspace tvm-cuda92-forge -sudo chown -R `whoami` $condadir/pkg +packages: + conda build tvm-libs + conda build tvm + conda build topi + conda build nnvm diff --git a/conda/build_cuda.py b/conda/build_cuda.py new file mode 100644 index 000000000000..47af6ce4564e --- /dev/null +++ b/conda/build_cuda.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import subprocess + +from jinja2 import Template + +CUDA_VERSIONS = ['10.0', '9.0'] + + +# Make sure that the cudnn version you set here is available +# for all the cuda versions that you want both from nvidia +# and from conda. 
+ +# These two must be in sync +CUDNN_FULL_VERSION = '7.3.1.20' +CUDNN_VERSION = '7.3.1' + + +condadir = os.path.dirname(sys.argv[0]) +condadir = os.path.abspath(condadir) +srcdir = os.path.dirname(condadir) + + +with open(os.path.join(condadir, 'Dockerfile.template')) as f: + docker_template = Template(f.read()) + + +def render_dockerfile(version): + txt = docker_template.render(cuda_version=version, + cudnn_short_version=CUDNN_VERSION, + cudnn_version=CUDNN_FULL_VERSION) + fname = os.path.join(condadir, + 'Dockerfile.cuda' + version.replace('.', '')) + with open(fname, 'w') as f: + f.write(txt) + return fname + + +def build_docker(version): + vv = version.replace('.', '') + fname = render_dockerfile(version) + tagname = f'tvm-cuda{ vv }-forge' + subprocess.run(['docker', 'build', '-t', tagname, + condadir, '-f', fname], check=True) + return tagname + + +def build_pkg(version): + tagname = build_docker(version) + subprocess.run(['docker', 'run', '--rm', '-v', f'{ srcdir }:/workspace', + tagname], check=True) + + +if __name__ == '__main__': + build_versions = CUDA_VERSIONS + if len(sys.argv) > 1: + build_versions = sys.argv[1:] + for version in build_versions: + build_pkg(version) diff --git a/conda/nnvm/meta.yaml b/conda/nnvm/meta.yaml index 883655f335cb..d948484a61e5 100644 --- a/conda/nnvm/meta.yaml +++ b/conda/nnvm/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 skip: True # [win] requirements: diff --git a/conda/topi/meta.yaml b/conda/topi/meta.yaml index bbba452a6422..f4bc8950d4c4 100644 --- a/conda/topi/meta.yaml +++ b/conda/topi/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 requirements: host: diff --git a/conda/tvm-libs/build.sh b/conda/tvm-libs/build.sh index d4cf2578b570..e0b85910475e 100644 --- a/conda/tvm-libs/build.sh +++ b/conda/tvm-libs/build.sh @@ -16,42 +16,25 @@ # specific language governing permissions and limitations # under the License. 
-# Fix for OSX build to hide the clang LLVM -rm -f ${BUILD_PREFIX}/bin/llvm-config -rm -rf ${BUILD_PREFIX}/lib/cmake - set -e -if [ -z "$PREFIX" ]; then - PREFIX="$CONDA_PREFIX" -fi - -if [ -z "$cuda" ] || [ "$cuda" == "False" ]; then - CUDA_OPT="" +if [ "$cuda" == "True" ]; then + CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" else - CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON" + CUDA_OPT="" fi if [ "$target_platform" == "osx-64" ]; then # macOS 64 bits METAL_OPT="" # Conda can only target 10.9 for now - TOOLCHAIN_OPT="" else METAL_OPT="" - if [ "$target_platform" == "linux-64" ]; then - # Linux 64 bits - TOOLCHAIN_OPT="-DCMAKE_TOOLCHAIN_FILE=${RECIPE_DIR}/../cross-linux.cmake" - else - # Windows (or 32 bits, which we don't support) - METAL_OPT="" - TOOLCHAIN_OPT="" - fi fi rm -rf build || true mkdir -p build cd build -cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=ON -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" $TOOLCHAIN_OPT .. +cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" .. make -j${CPU_COUNT} VERBOSE=1 make install cd .. diff --git a/conda/tvm-libs/meta.yaml b/conda/tvm-libs/meta.yaml index 5126f5b30359..aad8f251c2a6 100644 --- a/conda/tvm-libs/meta.yaml +++ b/conda/tvm-libs/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 0 + number: 1 string: cuda{{ cuda_version }}_{{ PKG_BUILDNUM }} # [cuda] requirements: @@ -39,6 +39,7 @@ requirements: - zlib # [linux] run: - {{ pin_compatible('cudatoolkit', lower_bound=cuda_version, max_pin='x.x') }} # [cuda] + - {{ pin_compatible('cudnn', lower_bound='7.3.1', max_pin='x') }} # [cuda] about: home: https://github.com/dmlc/tvm diff --git a/conda/tvm/meta.yaml b/conda/tvm/meta.yaml index 693237ce07c0..221dc7950f75 100644 --- a/conda/tvm/meta.yaml +++ b/conda/tvm/meta.yaml @@ -25,7 +25,7 @@ source: path: ../.. 
build: - number: 0 + number: 1 requirements: build: diff --git a/dmlc_tvm_commit_id.txt b/dmlc_tvm_commit_id.txt index 9eb98e3754f0..ceb215ac275d 100644 --- a/dmlc_tvm_commit_id.txt +++ b/dmlc_tvm_commit_id.txt @@ -1 +1 @@ -25c91d34c4de744cc9428944ccb1e84a72476ce5 +cbec5b94b87455f07918f7f4488c9a82a2d26708 diff --git a/vta/apps/tsim_example/cmake/modules/driver.cmake b/docker/Dockerfile.ci_jekyll similarity index 76% rename from vta/apps/tsim_example/cmake/modules/driver.cmake rename to docker/Dockerfile.ci_jekyll index c4c80637918f..5d3cf86dd6f5 100644 --- a/vta/apps/tsim_example/cmake/modules/driver.cmake +++ b/docker/Dockerfile.ci_jekyll @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. -file(GLOB TSIM_SW_SRC src/driver.cc) -add_library(driver SHARED ${TSIM_SW_SRC}) -target_include_directories(driver PRIVATE ${VTA_DIR}/include) +# CI docker Jekyll env for building website +# tag: v0.50 +FROM ubuntu:16.04 -if(APPLE) - set_target_properties(driver PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") -endif(APPLE) +RUN apt-get update && apt-get install -y sudo wget +RUN apt-get update && apt-get install -y ruby-full build-essential zlib1g-dev +RUN gem install jekyll bundler diff --git a/docker/install/ubuntu_install_nodejs.sh b/docker/install/ubuntu_install_nodejs.sh index 681f1b4a4641..8da9e2485797 100755 --- a/docker/install/ubuntu_install_nodejs.sh +++ b/docker/install/ubuntu_install_nodejs.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,7 +25,7 @@ apt-get install -y curl # The node install script fetched and executed here will update the # apt source list, hence the second apt-get update is necessary. -curl -s -S -L https://deb.nodesource.com/setup_6.x | bash - +curl -s -S -L https://deb.nodesource.com/setup_8.x | bash - apt-get update apt-get install -y nodejs diff --git a/docker/install/ubuntu_install_onnx.sh b/docker/install/ubuntu_install_onnx.sh index ec5b7f3b4964..a073389472b2 100755 --- a/docker/install/ubuntu_install_onnx.sh +++ b/docker/install/ubuntu_install_onnx.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,7 +21,7 @@ set -u set -o pipefail # fix to certain version for now -pip3 install onnx>=1.1.0 +pip3 install onnx>=1.4.1 pip3 install https://download.pytorch.org/whl/cu80/torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl pip3 install torchvision diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index d70f9890053d..802fb3b87d8c 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -35,7 +35,7 @@ pip2 install flatbuffers # Setup tflite from schema mkdir tflite cd tflite -wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/contrib/lite/schema/schema.fbs +wget -q 
https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs flatc --python schema.fbs cat <setup.py diff --git a/docs/_static/img/tvm-logo-square.png b/docs/_static/img/tvm-logo-square.png new file mode 100644 index 000000000000..37822d1a2d22 Binary files /dev/null and b/docs/_static/img/tvm-logo-square.png differ diff --git a/docs/api/python/dev.rst b/docs/api/python/dev.rst index e4b207bf4cbc..7bb938ca7517 100644 --- a/docs/api/python/dev.rst +++ b/docs/api/python/dev.rst @@ -61,6 +61,7 @@ tvm.ir_pass tvm.ir_pass.CanonicalSimplify tvm.ir_pass.StorageFlatten tvm.ir_pass.VectorizeLoop + tvm.ir_pass.SkipVectorize tvm.ir_pass.UnrollLoop tvm.ir_pass.ThreadSync tvm.ir_pass.StorageRewrite diff --git a/docs/api/python/relay/build_module.rst b/docs/api/python/relay/build_module.rst index 28dadea21e78..26164bf1ade9 100644 --- a/docs/api/python/relay/build_module.rst +++ b/docs/api/python/relay/build_module.rst @@ -22,17 +22,9 @@ tvm.relay.build_module .. autofunction:: tvm.relay.build_module.build -.. autofunction:: tvm.relay.build_module.build_config - .. autofunction:: tvm.relay.build_module.optimize .. autofunction:: tvm.relay.build_module.create_executor -.. autoclass:: tvm.relay.build_module.BuildConfig - :members: - -.. autofunction:: tvm.relay.build_module.build_config - :members: - .. autoclass:: tvm.relay.build_module.GraphExecutor :members: diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst new file mode 100644 index 000000000000..4eb7f9d8fea7 --- /dev/null +++ b/docs/api/python/relay/transform.rst @@ -0,0 +1,45 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.transform +---------------------- + +.. automodule:: tvm.relay.transform + +.. autofunction:: tvm.relay.transform.build_config + +.. autofunction:: tvm.relay.transform.module_pass + +.. autofunction:: tvm.relay.transform.function_pass + +.. autoclass:: tvm.relay.transform.Pass + :members: + +.. autoclass:: tvm.relay.transform.PassInfo + :members: + +.. autoclass:: tvm.relay.transform.PassContext + :members: + +.. autoclass:: tvm.relay.transform.ModulePass + :members: + +.. autoclass:: tvm.relay.transform.FunctionPass + :members: + +.. autoclass:: tvm.relay.transform.Sequential + :members: diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index eaa5dacd678e..ade0f1a5b390 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -88,6 +88,7 @@ List of operators topi.not_equal topi.greater_equal topi.less_equal + topi.all topi.logical_and topi.logical_or topi.logical_not @@ -98,6 +99,8 @@ List of operators topi.shape topi.layout_transform topi.image.resize + topi.argsort + topi.topk List of schedules @@ -140,6 +143,7 @@ topi .. autofunction:: topi.gather_nd .. autofunction:: topi.full .. autofunction:: topi.full_like +.. autofunction:: topi.all .. autofunction:: topi.max .. autofunction:: topi.sum .. autofunction:: topi.min @@ -161,6 +165,8 @@ topi .. autofunction:: topi.tile .. autofunction:: topi.shape .. autofunction:: topi.layout_transform +.. autofunction:: topi.argsort +.. 
autofunction:: topi.topk topi.nn ~~~~~~~ diff --git a/docs/conf.py b/docs/conf.py index e458b5549207..a1b66325a527 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -165,6 +165,8 @@ html_logo = "_static/img/tvm-logo-small.png" +html_favicon = "_static/img/tvm-logo-square.png" + # Output file base name for HTML help builder. htmlhelp_basename = project + 'doc' diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 788f1f8b50a3..e6df9becdc9e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -33,9 +33,9 @@ At the root of the TVM repository, we have following subdirectories that togethe - ``topi`` - Compute definitions and backend schedules for standard neural network operators. - ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. After the introduction of Relay, it remains in the codebase for backward compatibility. -Using standard Deep Learning terminologies, ``src/relay`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in the rest of ``src``. ``python`` provides python bindings for the C++ API and driver code that users can use to execute compilation. Operators corresponding to each node are registered in ``src/relay/op``. Implementations for operators are in ``topi``, and they are coded in either C++ or Python. +Using standard Deep Learning terminology, ``src/relay`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructure implemented in the rest of ``src``. ``python`` provides python bindings for the C++ API and driver code that users can use to execute compilation. Operators corresponding to each node are registered in ``src/relay/op``. Implementations of operators are in ``topi``, and they are coded in either C++ or Python. -Relay is the new IR for deep networks that is intended to replace NNVM. 
If you have used NNVM, Relay provides equivalent or better functionalities. In fact, Relay goes beyond a traditional way of thinking deep networks in terms of computational graphs. But for the purpose of this document, we can think of Relay as a traditional computational graph framework. You can read more about Relay `here `_. +Relay is the new IR for deep networks that is intended to replace NNVM. If you have used NNVM, Relay provides equivalent or better functionality. In fact, Relay goes beyond a traditional way of thinking deep networks in terms of computational graphs. But for the purpose of this document, we can think of Relay as a traditional computational graph framework. You can read more about Relay `here `_. When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compiler.build(...)`` for the older API), the following sequence of actions happens for each node in the graph: @@ -43,7 +43,7 @@ When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compile - Generate a compute expression and a schedule for the operator - Compile the operator into object code -One of the interesting aspects of TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. +One of the interesting aspects of the TVM codebase is that interoperability between C++ and Python is not unidirectional. Typically, all code that performs heavy lifting is implemented in C++, and Python bindings are provided for the user interface. This is also true in TVM, but in the TVM codebase, C++ code can also call into functions defined in a Python module. 
For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. ******************************************* Vector Add Example @@ -84,7 +84,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ args[4]); }); -We use ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of `PackedFunc `_. ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy. +We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc `_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy. A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``. @@ -141,7 +141,7 @@ Bound inference is the process where all loop bounds and sizes of intermediate b .. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html -``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects that changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. +``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. 
``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below. @@ -173,7 +173,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ } -``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: +The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: :: @@ -182,9 +182,9 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ *rv = BuildCUDA(args[0]); }); -``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. +The ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. 
``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. -``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. +The ``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages the CUDA driver API. The ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects. 
@@ -243,4 +243,4 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal } }; -This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, at the end all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase. +This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, in the end, all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase. diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst new file mode 100644 index 000000000000..a59620a0a861 --- /dev/null +++ b/docs/dev/virtual_machine.rst @@ -0,0 +1,314 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Putting the VM in TVM: The Relay Virtual Machine +================================================ + +Relay, a new program representation, has enabled the representation and optimization of +a great breadth of machine learning programs. 
+Unfortunately, by supporting a more expressive set of programs, we have +introduced several new execution challenges. + +Relay's interpreter can execute the full language but has notable limitations +that make it unsuited for production deployments. It is structured as an inefficient +interpreter that performs AST traversal to execute the program. This approach is conceptually +simple but inefficient, as the AST traversal heavily relies on indirection. + +There are further challenges in compiling dynamic code, such as dynamic scheduling and allocation, +fully dynamic tensor shapes, and control flow. The interpreter offers simple solutions +for these, but none is sufficiently compelling or optimized. + +The second execution mechanism is the existing graph runtime. In order to target Relay +programs to this, we compile a small subset of them to the old graph format and execute +them on the runtime. Graph runtime provides a fast execution experience but only for a very limited +subset of Relay programs. + +An alternative but not-standard approach is Relay's ahead-of-time compiler, +which compiles a Relay program into a shared library containing an ahead- +of-time implementation. The ahead-of-time compiler provides compelling performance +but is difficult to extend and instrument, which can only be done by modifying the +code generation and optimization mechanisms. + +The Relay virtual machine is intended to be a framework that balances these competing +approaches, providing a dynamic execution environment which can be extended, instrumented, +and integrated with other approaches like ahead-of-time compilation via a flexible extension +mechanism. + +The virtual machine is designed to strike a balance between performance and flexibility +when deploying and executing Relay programs, without giving up the benefits of TVM. 
+ +Virtual machine (VM) design is a well-studied area in programming languages and systems, +and there have been various virtual machine designs for both full-fledged +and embedded programing languages. +Previous language VM designs have been heavily tailored to the execution profile of traditional programs. +Traditional programs manipulate small scalar values and consist of a large number of low-level instructions. +The sheer quantity of instructions requires instruction execution and dispatch to be extremely efficient. +In the context of machine learning we manipulate primarily tensor values, using a (relatively) +low number of high level instructions. ML programs' cost centers are expensive operator invocations, +such as GEMM or convolution, over a large input. Due to the execution profile exhibited by ML programs, +micro-optimizations present in scalar VMs are dramatically less important. + +TVM has provided strong support for vision models, +but we want to grow to support a wider variety of models. +The graph runtime is able to utilize the fully static nature of the input graphs to perform +aggressive optimization such as fully static allocation, and optimal memory reuse. +When we introduce models which make use of control flow, recursion, dynamic shapes, and dynamic +allocation, we must change how execution works. A virtual machine for Relay is a natural choice. + +The rest of this document provides a high-level overview of the Relay +virtual machine design and its instruction set. + +Design +------ + +The VM's design is focused on simplicity without sacrificing performance. +In order to accomplish this we have focused on designing a tensor VM rather than a scalar VM. + +In the tensor VM setting, we optimize for cheap “allocation” of objects (by trying to avoid real allocation), +reuse of static fragments, and the ability to do dynamic shape (i.e jagged tensors). 
+ +Instruction Set +~~~~~~~~~~~~~~~ + +The choices of an instruction set and instruction representation are the most critical design decisions for a VM. +The current representation of the instructions is a tagged union containing the op-code and the data payload. An important design decision is the level of abstraction of the instructions (RISC vs. CISC) and how they take their data (fixed-width instruction encoding vs. variable-length encoding). The current version is closer to CISC, with complex instructions like AllocTensor, and is variable-length due to the inclusion of the shape as part of the instruction. The current instruction set is very high-level and corresponds roughly to high-level operations in Relay. + +Ret +^^^ +**Arguments**: +:: + RegName dst + RegName result + +Returns the object in register `result` to caller's register `dst`. + +InvokePacked +^^^^^^^^^^^^ +**Arguments**: +:: + size_t packed_index + size_t arity + size_t output_size + RegName* packed_args + +Invoke the packed function denoted by `packed_index`. The `arity` +and `output_size` are used to inform the VM how many inputs and +outputs to expect. `packed_args` stores the list of argument registers. + +AllocTensor +^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + RegName shape_register + size_t ndim + DLDataType dtype + +Allocate a tensor value of the appropriate shape (stored in `shape_register`) and `dtype`. The result +is saved to register `dst`. + +AllocDatatype +^^^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t tag + size_t num_fields + RegName* datatype_fields + +Allocate a data type with the tag `tag` using the `num_fields` entries +from registers `datatype_fields`. The result is saved to register `dst`. + +AllocClosure +^^^^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t clo_index + size_t num_freevar + RegName* free_vars; + +Allocate a closure with the VMFunction at `clo_index` as +its code, and the `num_freevar` entries from registers in +`free_vars`. 
The result is saved to register `dst`. + +GetField +^^^^^^^^ +**Arguments**: +:: + RegName dst + RegName object + size_t field_index + +Get the field value with index `field_index` from `object`. And saves the result to register `dst`. + +If +^^ +**Arguments**: +:: + RegName if_cond + size_t true_offset + size_t false_offset + +Check if the object at register `if_cond` is `true` or `false`. +If `true`, relative jump by `true_offset`, else relative +jump by `false_offset`. + +Goto +^^^^ +**Arguments**: +:: + size_t pc_offset + +Relative unconditional jump by `pc_offset`. + +Invoke +^^^^^^ +**Arguments**: +:: + size_t func_index + +Invoke function at `func_index`, consumes the number of arguments contained in the VMFunction's +arity field. + +InvokeClosure +^^^^^^^^^^^^^ +**Arguments**: +:: + RegName closure + size_t closure_args_num + RegName* closure_args + +Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction. + +LoadConst +^^^^^^^^^ +**Arguments**: +:: + RegName dst + size_t const_index + +Load the constant at `const_index` from the constant pool. The result is saved to register `dst`. + +Object Representation +~~~~~~~~~~~~~~~~~~~~~ +We use a simple object representation that uses shared pointers and tagging. +There is a huge space of possible object representations trade-offs, but we +believe micro-optimizing this code has little to no effect on the end-to-end performance. + +:: + + struct ObjectCell { + ObjectTag tag; + ... + }; + + struct Object { + std::shared_ptr ptr; + ... + } + +See `include/tvm/runtime/vm.h` for more details. + +Currently, we support 3 types of objects: tensors, data types, and closures. 
+ +:: + + VMObject VMTensor(const tvm::runtime::NDArray& data); + VMObject VMDatatype(size_t tag, const std::vector& fields); + VMObject VMClosure(size_t func_index, std::vector free_vars); + + +Stack and State +~~~~~~~~~~~~~~~ + +The Relay VM maintains a stack frame, which contains information about how to resume the +previous call. Registers are allocated in a continuous space (virtual register file) for each function. + +We keep track of a set of Relay functions we have called, a pointer into its bytecode, an offset into the byte code (known as the program counter). + +:: + + struct VirtualMachine { + ... + std::vector frames; + ... + // Current function. + size_t func_index; + // Pointer into the current function's instructions. + const Instruction* code; + // Current program counter relative to the code pointer. + size_t pc; + ... + }; + + +Dispatch Loop +~~~~~~~~~~~~~ +A critical piece of a VM is the dispatch loop. The dispatch loop usually dominates the execution time of a +virtual machine, but we have experimentally found this not to be the case for Relay. We have just implemented +a simple `switch`/`goto` dispatch loop which dispatches based on instruction op code. + +This loop is implemented by `VirtualMachine::Run()`. + +VM Compiler +~~~~~~~~~~~ + +An important part of this infrastructure is a compiler from Relay's full IR into a sequence of bytecode. +The VM compiler transforms a `tvm::relay::Module` into a `tvm::relay::vm::VirtualMachine`. The virtual +machine contains a set of compiled functions, the compiled functions are contained in `tvm::relay::vm::Function`. The functions contain metadata about the the function as well as its compiled bytecode. For full definitions of the data structures see `vm.h`. + +Optimizations +~~~~~~~~~~~~~ + +There are quite a few optimizations required by the VM compiler. + +We have implemented them in the old pass style, but plan to port them to +the new pass manager (#2546) before merging. 
+ +Optimizations marked with `TODO` are not implemented yet. + +- A-Normal Form +- Lambda Lift (see `src/relay/vm/lambda_lift.cc`) +- Inline Primitives (see `src/relay/vm/inline_primitives.cc`) +- Inliner (see `src/relay/pass/inliner.cc`) +- Constant Pool Layout (see `src/relay/backend/vm/compiler.cc`) +- ADT Tag Allocation (see `src/relay/backend/vm/compiler.cc`) +- Tail Call Optimization (TODO) +- Liveness Analysis (TODO) + +Serialization +~~~~~~~~~~~~~ + +A final and yet-to-be-implemented part of the VM design is serialization. The accompanying PR will introduce both the bytecode and its serialization, as well as VM-level serialization. The design premise is that a VM can be efficiently stored to disk and resumed at a later time. This would also allow us to efficiently schedule many models on to a single machine in order to obtain good utilization. + +Unresolved Questions +~~~~~~~~~~~~~~~~~~~~ + +How do we handle dynamic shapes? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +TODO + +How can we modify the VM to support JIT compilation of certain code paths? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +In the code generation space there are still many tradeoffs to be analyzed and the VM is designed +to be very flexible so we can modify it for future experiments. + +How do we support heterogenous execution? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Heterogenous execution should work out of the box assuming we have annotated the appropriate device copies. +In order to do this properly we need to run the device annotation and copying passes. diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 3a769dee2dce..1ea8f3478341 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -192,6 +192,12 @@ Python dependencies .. code:: bash pip install --user tornado psutil xgboost + + * If you want to parse Relay text format progams, you must use Python 3 and run the following + + .. 
code:: bash + + pip install --user mypy orderedset antlr4-python3-runtime Install Contrib Libraries diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index cd5677293571..28ee99e77981 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -135,6 +135,7 @@ This level enables additional math and transform operators. tvm.relay.greater_equal tvm.relay.less tvm.relay.less_equal + tvm.relay.all tvm.relay.logical_and tvm.relay.logical_or tvm.relay.logical_not @@ -171,6 +172,7 @@ This level enables additional math and transform operators. :nosignatures: tvm.relay.argsort + tvm.relay.topk **Level 10: Temporary Operators** @@ -277,6 +279,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.greater_equal .. autofunction:: tvm.relay.less .. autofunction:: tvm.relay.less_equal +.. autofunction:: tvm.relay.all .. autofunction:: tvm.relay.logical_and .. autofunction:: tvm.relay.logical_or .. autofunction:: tvm.relay.logical_not @@ -307,6 +310,7 @@ Level 5 Definitions Level 6 Definitions ------------------- .. autofunction:: tvm.relay.argsort +.. autofunction:: tvm.relay.topk Level 10 Definitions diff --git a/docs/vta/install.md b/docs/vta/install.md index 233bb5ca0260..6c87b4edd288 100644 --- a/docs/vta/install.md +++ b/docs/vta/install.md @@ -84,7 +84,7 @@ This guide covers the following themes: Setup your Pynq board based on the [Pynq board getting started tutorial](http://pynq.readthedocs.io/en/latest/getting_started.html). You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point). -* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.3](http://www.pynq.io/board.html) (released October 3rd 2018), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). 
+* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.4](http://www.pynq.io/board.html)(released February 22rd 2019), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). * For this test setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Ethernet setup instructions. To be able to talk to the board, make sure to [assign your computer a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip) Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board: @@ -208,7 +208,7 @@ chmod u+x Xilinx_Vivado_SDK_Web_2018.2_0614_1954_Lin64.bin #### Xilinx Vivado GUI Installer Steps -At this point you've launched the Vivado 2017.1 Installer GUI program. +At this point you've launched the Vivado 2018.2 Installer GUI program. 1. Click “Next” on the *Welcome* screen. 2. On the *Select Install Type* screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next” . diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 9a8d9d372956..c506268cb14b 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -48,11 +48,7 @@ namespace arith { // Forward declare Analyzer class Analyzer; -/*! - * \brief reference class to ConstIntBoundNode - * \sa ConstIntBoundNode - */ -class ConstIntBound; + /*! * \brief Constant integer up and lower bound(inclusive). * Useful for value bound analysis. @@ -69,8 +65,6 @@ class ConstIntBoundNode : public Node { v->Visit("max_value", &max_value); } - TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value); - /*! \brief Number to represent +inf */ static const constexpr int64_t kPosInf = std::numeric_limits::max(); /*! 
@@ -83,7 +77,23 @@ class ConstIntBoundNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node); }; -TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode); +/*! + * \brief reference class to ConstIntBoundNode + * \sa ConstIntBoundNode + */ +class ConstIntBound : public NodeRef { + public: + /*! + * \brief constructor by fields. + * \param min_value The mininum value. + * \param max_value The maximum value. + */ + TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value); + + static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; + TVM_DEFINE_NODE_REF_METHODS(ConstIntBound, NodeRef, ConstIntBoundNode); +}; /*! * \brief Analyzer to get constant integer bound over expression. @@ -133,11 +143,6 @@ class ConstIntBoundAnalyzer { Impl* impl_; }; -/*! - * \brief reference of ModularSetNode - * \sa ModularSetNode - */ -class ModularSet; /*! * \brief Range of a linear integer function. * Use to do specify the possible index values. @@ -162,13 +167,20 @@ class ModularSetNode : public Node { v->Visit("base", &base); } - TVM_DLL static ModularSet make(int64_t coeff, int64_t base); - static constexpr const char* _type_key = "arith.ModularSet"; TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node); }; -TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode); +/*! + * \brief reference of ModularSetNode + * \sa ModularSetNode + */ +class ModularSet : public NodeRef { + public: + TVM_DLL ModularSet(int64_t coeff, int64_t base); + + TVM_DEFINE_NODE_REF_METHODS(ModularSet, NodeRef, ModularSetNode); +}; /*! * \brief Analyzer to get modular information over expression. @@ -278,14 +290,14 @@ class CanonicalSimplifier { }; /*! - * \brief A RAII constraint context. + * \brief Constraint context. 
* * \code * * Var("x"); * arith::Analyzer analyzer; * { - * arith::ConstraintContext cctx(&analyzer, x % 3 == 0); + * With scope(&analyzer, x % 3 == 0); * CHECK_EQ(analyzer.modular_set(x)->coeff, 3); * } * // constraint no longer in effect. @@ -294,88 +306,36 @@ class CanonicalSimplifier { * \endcode */ class ConstraintContext { - public: + private: + // declare friend to enable with. + friend class With; /*! * \brief Construct a constraint context. * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION; - /*! \brief destructor */ - ~ConstraintContext() DMLC_THROW_EXCEPTION { - exit_(); - } - - private: + ConstraintContext(Analyzer* analyzer, Expr constraint) + : analyzer_(analyzer), constraint_(constraint) {} + // enter the scope. + void EnterWithScope(); + // exit the scope. + void ExitWithScope(); + /*! \brief The analyzer */ + Analyzer* analyzer_; + /*! \brief The constraint */ + Expr constraint_; /*! \brief function to be called in recovery */ std::function exit_; }; -/*! - * \brief Analyzer that contains bunch of sub-analyzers. - * - * Each sub-analyzer can make use of another sub-analyzer - * by weak reference of this. - * - * NOTE for sub-analyzer developers: - * If the analyzer uses memoization, we need to clear the internal - * cache when information about a Var has been overrideen. - */ -class Analyzer { - public: - /*! \brief sub-analyzer: const integer bound */ - ConstIntBoundAnalyzer const_int_bound; - /*! \brief sub-analyzer: modular set */ - ModularSetAnalyzer modular_set; - /*! \brief sub-analyzer rewrite simplify */ - RewriteSimplifier rewrite_simplify; - /*! \brief sub-analyzer canonical simplify */ - CanonicalSimplifier canonical_simplify; - /*! \brief constructor */ - Analyzer(); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to expr. - * - * Each var can only be binded once. 
- * - * \param var The variable. - * \param expr The expression we bind to. - */ - void Bind(const VarExpr& var, const Expr& expr); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to a range. - * - * Each var can only be binded once. - * - * \param var The variable. - * \param range The range we bind to. - */ - void Bind(const VarExpr& var, const Range& range); - /*! - * \brief Whether can we proof expr >= val. - - * Non-negative proof is very useful in integer analysis - * to lower divisions and mods given difference in trunc and ceil mode. - * - * \param expr The expression. - * \param lower_bound The lower bound. - * \return Whether we can proof it. - * - * \note Analyzer will call into sub-analyzers to get the result. - */ - bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); -}; - //----------------------------------------------- -// Integer set abstraction API. +// Integer set data structure. // // This is a API build on top of the base // integer analysis API to provide set analysis. //------------------------------------------------ /*! - * \brief Sign of an expression or set. + * \brief Sign type of an integer expression. */ enum SignType { kPositive, @@ -384,8 +344,13 @@ enum SignType { kUnknown }; -// internal node container of int set. -struct IntSetNode; +/*! + * \brief Base class of all IntSet containers. + */ +struct IntSetNode : public Node { + static constexpr const char* _type_key = "IntSet"; + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +}; /*! * \brief Integer set class, represent a set of integers in one dimension. @@ -407,11 +372,6 @@ class IntSet : public NodeRef { * \return The covering range. */ Range cover_range(Range max_range) const; - /*! - * \brief find an interval that covers the set. - * \return The covering interval set. - */ - IntSet cover_interval() const; /*! \return Lower bound of the set */ Expr min() const; /*! 
\return upper bound of the set */ @@ -476,33 +436,91 @@ class IntSet : public NodeRef { }; /*! - * \brief Base class of all IntSet containers. + * \brief Integer set analyzer. */ -struct IntSetNode : public Node { - static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +class IntSetAnalyzer { + public: + /*! + * \brief Find a symbolic integer set that contains all possible values of + * expr given the domain of each variables. + * + * \param expr The expression of interest. + * \param dom_map The domain map to indicate which variable to relax. + * \return the result of the analysis. + */ + IntSet operator()(const Expr& expr, const Map& dom_map); + + private: + friend class Analyzer; + explicit IntSetAnalyzer(Analyzer* parent); + ~IntSetAnalyzer(); + class Impl; + /*! \brief Internal impl */ + Impl* impl_; }; /*! - * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] - * Where coeff[i] and base are invariant of var[j] for all i and j. + * \brief Analyzer that contains bunch of sub-analyzers. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return [coeff[i]] if it is possible, empty array if it is not. - */ -Array DetectLinearEquation(const Expr& e, const Array& vars); - -/*! - * \brief Detect if expression corresponds to clip bound of the vars + * Each sub-analyzer can make use of another sub-analyzer + * by weak reference of this. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value - * return empty if the e does not match the pattern. + * NOTE for sub-analyzer developers: + * If the analyzer uses memoization, we need to clear the internal + * cache when information about a Var has been overridden. 
*/ -Array DetectClipBound(const Expr& e, const Array& vars); +class Analyzer { + public: + /*! \brief sub-analyzer: const integer bound */ + ConstIntBoundAnalyzer const_int_bound; + /*! \brief sub-analyzer: modular set */ + ModularSetAnalyzer modular_set; + /*! \brief sub-analyzer rewrite simplify */ + RewriteSimplifier rewrite_simplify; + /*! \brief sub-analyzer canonical simplify */ + CanonicalSimplifier canonical_simplify; + /*! \brief sub-analyzer: int set */ + IntSetAnalyzer int_set; + /*! \brief constructor */ + Analyzer(); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to expr. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param expr The expression we bind to. + */ + void Bind(const VarExpr& var, const Expr& expr); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to a range. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const VarExpr& var, const Range& range); + /*! + * \brief Whether can we prove expr >= val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param lower_bound The lower bound. + * \return Whether we can prove it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); +}; +//----------------------------------------------- +// Integer set legacy API. +//------------------------------------------------ /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -621,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond, */ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); +// Expression pattern detector. +/*! 
+ * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] + * Where coeff[i] and base are invariant of var[j] for all i and j. + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return [coeff[i]] if it is possible, empty array if it is not. + */ +Array DetectLinearEquation(const Expr& e, + const Array& vars); + +/*! + * \brief Detect if expression corresponds to clip bound of the vars + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value + * return empty if the e does not match the pattern. + */ +Array DetectClipBound(const Expr& e, + const Array& vars); + // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/include/tvm/base.h b/include/tvm/base.h index ae2d91ff8523..f358f7f5d447 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -39,21 +39,24 @@ using ::tvm::Node; using ::tvm::NodeRef; using ::tvm::AttrVisitor; -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ - class TypeName : public ::tvm::NodeRef { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ - }; \ +/*! + * \brief Macro to define common node ref methods. + * \param TypeName The name of the NodeRef. + * \param BaseTypeName The Base type. + * \param NodeName The node container type. 
+ */ +#define TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseTypeName, NodeName) \ + TypeName() {} \ + explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseTypeName(n) {} \ + const NodeName* operator->() const { \ + return static_cast(node_.get()); \ + } \ + operator bool() const { return this->defined(); } \ + using ContainerType = NodeName; /*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. + * \brief Macro to define CopyOnWrite function in a NodeRef. + * \param NodeName The Type of the Node. * * CopyOnWrite will generate a unique copy of the internal node. * The node will be copied if it is referenced by multiple places. @@ -70,25 +73,77 @@ using ::tvm::AttrVisitor; * * \endcode */ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TypeName() {} \ - explicit TypeName(::tvm::NodePtr<::tvm::Node> n) : BaseType(n) {} \ - const NodeName* operator->() const { \ - return static_cast(node_.get()); \ - } \ - inline NodeName* CopyOnWrite() { \ +#define TVM_DEFINE_NODE_REF_COW(NodeName) \ + NodeName* CopyOnWrite() { \ CHECK(node_ != nullptr); \ if (!node_.unique()) { \ NodePtr n = make_node(*(operator->())); \ NodePtr(std::move(n)).swap(node_); \ } \ return static_cast(node_.get()); \ - } \ - using ContainerType = NodeName; \ + } + +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ::tvm::NodeRef { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, ::tvm::NodeRef, NodeName); \ + }; \ + +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_NODE_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_NODE_REF_COW(NodeName); \ }; +/*! 
+ * \brief RAII wrapper function to enter and exit a context object + * similar to Python's with syntax. + * + * \code + * // context class + * class MyContext { + * private: + * friend class With; + * MyContext(arguments); + * void EnterWithScope(); + * void ExitWithScope(); + * }; + * + * { + * With scope(arguments); + * // effect takes place. + * } + * \endcode + * + * \tparam ContextType Type of the context object. + */ +template +class With { + public: + /*! + * \brief constructor. + * Enter the scope of the context. + */ + template + explicit With(Args&& ...args) + : ctx_(std::forward(args)...) { + ctx_.EnterWithScope(); + } + /*! \brief destructor, leaves the scope of the context. */ + ~With() DMLC_THROW_EXCEPTION { + ctx_.ExitWithScope(); + } + + private: + /*! \brief internal context type. */ + ContextType ctx_; +}; /*! * \brief save the node as well as all the node it depends on as json. diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 208f086f86c0..187a74552241 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -37,7 +37,7 @@ namespace tvm { /*! * \brief Container for target device information. -* Use target::llvm, target::cuda etc functions instead of constructing directly. +* Use target::llvm, target::cuda etc functions instead of constructing directly. */ class TargetNode : public Node { public: @@ -89,65 +89,47 @@ class TargetNode : public Node { mutable std::string str_repr_; }; +/*! \brief reference class to the target. */ class Target : public NodeRef { public: Target() {} explicit Target(NodePtr n) : NodeRef(n) {} - /*! * \brief Create a Target given a string * \param target_str the string to parse */ - TVM_DLL static Target create(const std::string& target_str); - - /*! - * \brief Push a new target context onto the thread local stack. The Target on top of - * the stack is used to determine which specialization to use when invoking a GenericFunc. 
- * \param target The target to set as the current context. - */ - TVM_DLL static void EnterTargetScope(const tvm::Target& target); - - /*! - * \brief Pop a target off the thread local context stack, restoring the previous target - * as the current context. - */ - TVM_DLL static void ExitTargetScope(); - + TVM_DLL static Target Create(const std::string& target_str); /*! - * \brief Get the current target context from thread local storage. - * \param allow_not_defined If the context stack is empty and this is set to true, an - * undefined Target will be returned. Otherwise, an empty context stack will cause a - * runtime error. - * \return The target that is the current context. The target may not be defined if - * allow_not_defined is true. - */ - TVM_DLL static tvm::Target current_target(bool allow_not_defined = true); + * \brief Get the current target context from thread local storage. + * \param allow_not_defined If the context stack is empty and this is set to true, an + * undefined Target will be returned. Otherwise, an empty context stack will cause a + * runtime error. + * \return The target that is the current context. The target may not be defined if + * allow_not_defined is true. + */ + TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - inline const TargetNode* operator->() const { + const TargetNode* operator->() const { return static_cast(node_.get()); } using ContainerType = TargetNode; -}; - -/*! - * \brief RAII container to provide a scoped target context. Pushes a target onto the - * context stack when constructed, and pops it when destructed. - */ -struct TargetContext { + class Internal; + private: + // enable with syntax. + friend class Internal; + friend class With; /*! - * \brief Enter a new target context. The given target becomes the new current context. - * When the TargetContext is destructed, the previous context is restored. - * \param target The target to set as the new current context. 
+ * \brief Push a new target context onto the thread local stack. + * The Target on top of the stack is used to determine which + * specialization to use when invoking a GenericFunc. */ - explicit TargetContext(const tvm::Target& target) { - Target::EnterTargetScope(target); - } - - /*! \brief Destructor. Pops the context off the thread local stack. */ - ~TargetContext() { - Target::ExitTargetScope(); - } + TVM_DLL void EnterWithScope(); + /*! + * \brief Pop a target off the thread local context stack, + * restoring the previous target as the current context. + */ + TVM_DLL void ExitWithScope(); }; /*! \brief This namespace provides functions to construct Target instances */ @@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector& options = } // namespace target -class BuildConfig; - /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class BuildConfigNode : public Node { public: /*! @@ -246,6 +226,9 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable select rewriting. */ bool disable_select_rewriting = false; + /*! \brief Whether to disable loop vectorization. */ + bool disable_vectorize = false; + void VisitAttrs(AttrVisitor* v) final { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); @@ -260,6 +243,7 @@ class BuildConfigNode : public Node { v->Visit("dump_pass_ir", &dump_pass_ir); v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("disable_select_rewriting", &disable_select_rewriting); + v->Visit("disable_vectorize", &disable_vectorize); } static constexpr const char* _type_key = "BuildConfig"; @@ -267,69 +251,48 @@ class BuildConfigNode : public Node { }; /*! -* \brief Container for build configuration options -*/ + * \brief Build configuration for compilations. 
+ */ class BuildConfig : public ::tvm::NodeRef { public: BuildConfig() {} explicit BuildConfig(NodePtr<::tvm::Node> n) : NodeRef(n) {} - const BuildConfigNode* operator->() const { return static_cast(node_.get()); } - BuildConfigNode* operator->() { return static_cast(node_.get()); } - /*! - * \brief Push a new BuildConfig context onto the thread local stack. - * \param build_config The configuration to set as the current context. - */ - TVM_DLL static void EnterBuildConfigScope(const tvm::BuildConfig& build_config); - - /*! - * \brief Pop a build config off the thread local context stack, restoring the previous - * configuration as the current context. + * \brief Construct a BuildConfig containing an empty build config node. + * \return The new BuildConfig */ - TVM_DLL static void ExitBuildConfigScope(); - + TVM_DLL static BuildConfig Create(); /*! * \brief Get the current BuildConfig context from thread local storage, or a default * configuration if a BuildConfig scope has not been entered. * \return The configuration that is the current context. */ - TVM_DLL static tvm::BuildConfig Current(); + TVM_DLL static BuildConfig Current(); using ContainerType = BuildConfigNode; -}; + class Internal; -/*! - * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the - * context stack when constructed, and pops it when destructed. - */ -struct BuildConfigContext { + private: + // Enable with syntax. + friend class With; /*! - * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current - * context. When the BuildConfigContext is destructed, the previous context is restored. - * \param build_config The BuildConfig to set as the new current context. + * \brief Push a new BuildConfig context onto the thread local stack. */ - explicit BuildConfigContext(const tvm::BuildConfig& build_config) { - BuildConfig::EnterBuildConfigScope(build_config); - } + TVM_DLL void EnterWithScope(); - /*! \brief Destructor. 
Pops the context off the thread local stack. */ - ~BuildConfigContext() { - BuildConfig::ExitBuildConfigScope(); - } + /*! + * \brief Pop a build config off the thread local context stack, + * restoring the previous configuration as the current context. + */ + TVM_DLL void ExitWithScope(); }; -/*! -* \brief Construct a BuildConfig containing a new BuildConfigNode -* \return The new BuildConfig -*/ -TVM_DLL BuildConfig build_config(); - /*! * \brief Build a LoweredFunc given a schedule, args and binds * \param sch The schedule to lower. diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 4ef3effaf251..f289bdd810d5 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -33,6 +33,7 @@ #include "ir.h" namespace tvm { + /*! * \brief Make a const value with certain data type. * \param t The target type. @@ -427,6 +428,13 @@ TVM_DLL Expr abs(Expr x); */ TVM_DLL Expr sum(Expr source, Array axis); +/*! + * \brief logical And of of source expression over axis + * \param source The source expression. + * \param axis List of iteration variables that will be used for reduction. + */ +TVM_DLL Expr all(Expr source, Array axis); + /*! * \brief max of of source expression over axis * \param source The source expression. @@ -551,6 +559,12 @@ inline Expr MakeConstScalar(Type t, ValueType value) { if (t.is_int()) return ir::IntImm::make(t, static_cast(value)); if (t.is_uint()) return ir::UIntImm::make(t, static_cast(value)); if (t.is_float()) return ir::FloatImm::make(t, static_cast(value)); + // For now, we store const scalar values of custom datatypes within doubles; later, during the + // datatypes lowering pass, we will lower the value to its true representation in the format + // specified by the datatype. + // TODO(gus) when do we need to start worrying about doubles not being precise enough? 
+ if (static_cast(t.code()) >= static_cast(kCustomBegin)) + return ir::FloatImm::make(t, static_cast(value)); LOG(FATAL) << "cannot make const for type " << t; return Expr(); } diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 20b56e0676eb..e1c92e50e6ad 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -250,35 +250,42 @@ Stmt UnrollLoop(Stmt stmt, /*! * \brief vectorize the constant loops - * \param stmt The statment to be vectorized. + * \param stmt The statement to be vectorized. * \return Transformed stmt. */ Stmt VectorizeLoop(Stmt stmt); +/*! + * \brief convert vectorized loops into serialized loops + * \param stmt The statement to skip vectorization on. + * \return Transformed stmt. + */ +Stmt SkipVectorize(Stmt stmt); + /*! * \brief instruments bound checkers. -* \param stmt The statment to be instrumented. -* \return Instrumented Stmt. +* \param stmt The statement to be instrumented. +* \return Instrumented stmt. */ Stmt InstrumentBoundCheckers(Stmt stmt); /*! * \brief Inject virtual thread loops into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectVirtualThread(Stmt stmt); /*! * \brief Inject prefetch instructions into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \return Transformed stmt. */ Stmt InjectPrefetch(Stmt stmt); /*! * \brief Inject double buffer into stmt. - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param split_loop Loop splitting factor. * \return Transformed stmt. */ @@ -287,7 +294,7 @@ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); /*! * \brief Inject copy intrinsics with optional pad. * - * \param stmt The statment to be transformed. + * \param stmt The statement to be transformed. * \param pragma_key The pragma key for hint of copy. 
* \param fintrin The function with signature * @@ -308,7 +315,7 @@ Stmt InjectCopyIntrin(Stmt stmt, * Trying to share space between allocations to make * a static allocation plan when possible. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt StorageRewrite(Stmt stmt); @@ -324,7 +331,7 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop); /*! * \brief Detect and insert sync points to co-processor. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt CoProcSync(Stmt stmt); @@ -332,7 +339,7 @@ Stmt CoProcSync(Stmt stmt); /*! * \brief Lift common attrs with attr_key to outer scope. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \param attr_key The attribute key to be checked. * \return Transformed stmt. */ @@ -340,7 +347,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); /*! * \brief Detect and rewrite unsafe select that contains memory access. - * \param stmt The statment to be rewritten. + * \param stmt The statement to be rewritten. * \return Transformed stmt. */ Stmt RewriteUnsafeSelect(Stmt stmt); @@ -349,7 +356,7 @@ Stmt RewriteUnsafeSelect(Stmt stmt); * \brief Lower attached storage access information. * Do this pass after all storage access analysis finish. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt LowerStorageAccessInfo(Stmt stmt); @@ -358,7 +365,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt); * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. * - * \param stmt The stmt to be trasnformed + * \param stmt The stmt to be transformed * \return Transformed stmt. */ Stmt DecorateDeviceScope(Stmt stmt); @@ -381,7 +388,7 @@ Stmt DecorateDeviceScope(Stmt stmt); * \return a LoweredFunc with the specified signiture. 
* * \note - * The function signiture have two cases + * The function signature have two cases * * let num_packed_args = len(api_args) - num_unpacked_args; * @@ -500,6 +507,17 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); +/*! + * \brief Lower custom datatypes. + * + * See tvm::datatypes::Registry for more information on adding custom datatypes. + * + * \param f The device function to be lowered. + * \param target The target device. + * \return Transformed function. + */ +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); + /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 15a8c1215177..38dc39bbe7a7 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -286,6 +286,8 @@ class TensorComputeOpNode : public BaseComputeOpNode { Array inputs; /*! \brief region of input tensors */ Array input_regions; + /*! \brief scalar expression inputs */ + Array scalar_inputs; /*! 
\brief constructor */ TensorComputeOpNode() {} // override functions @@ -314,6 +316,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("intrin", &intrin); v->Visit("inputs", &inputs); v->Visit("input_regions", &input_regions); + v->Visit("scalar_inputs", &scalar_inputs); } static Operation make(std::string name, std::string tag, @@ -322,7 +325,8 @@ class TensorComputeOpNode : public BaseComputeOpNode { int schedulable_ndim, TensorIntrin intrin, Array tensors, - Array regions); + Array regions, + Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode); diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 20f135c11bba..ce14a6a2d535 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_ALGORITHM_H_ #include +#include #include namespace tvm { @@ -48,6 +49,31 @@ struct ArgsortAttrs : public tvm::AttrsNode { } }; +struct TopKAttrs : public tvm::AttrsNode { + int k; + int axis; + bool is_ascend; + std::string ret_type; + DataType dtype; + + TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { + TVM_ATTR_FIELD(k).set_default(1) + .describe("Number of top elements to select"); + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Axis along which to sort the input tensor."); + TVM_ATTR_FIELD(ret_type).set_default("both") + .describe("The return type [both, values, indices]." + "both - return both top k data and indices." + "values - return top k data only." + "indices - return top k indices only."); + TVM_ATTR_FIELD(is_ascend).set_default(false) + .describe("Whether to sort in ascending or descending order." 
+ "By default, sort in descending order"); + TVM_ATTR_FIELD(dtype).set_default(NullValue()) + .describe("Data type of the output indices."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index f3caf213575d..e74c0f53ad65 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_IMAGE_H_ #include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 8a1aca0a4b4a..0fbf984817b9 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_NN_H_ #include +#include #include namespace tvm { diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 1b82412d0482..5e315185ea7f 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_TRANSFORM_H_ #include +#include #include namespace tvm { @@ -101,7 +102,8 @@ struct TakeAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(mode).set_default("clip") .describe("Specify how out-of-bound indices will behave." "clip - clip to the range (default)" - "wrap - wrap around the indices"); + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 11b4ebfcfaad..b98bbfc5988e 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_VISION_H_ #include +#include #include namespace tvm { @@ -79,10 +80,16 @@ struct MultiBoxTransformLocAttrs /*! 
\brief Attributes used in get_valid_counts operator */ struct GetValidCountsAttrs : public tvm::AttrsNode { double score_threshold; + int id_index; + int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { TVM_ATTR_FIELD(score_threshold).set_default(0.0) .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0) + .describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1) + .describe("Index of the scores/confidence of boxes."); } }; diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index 6b9a1fa7b7c6..5189fd982d37 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -83,7 +83,7 @@ struct Error : public dmlc::Error { * * The final mode represents the old mode, if we report an error that has no span or * expression, we will default to throwing an exception with a textual representation - * of the error and no indication of where it occured in the original program. + * of the error and no indication of where it occurred in the original program. * * The latter mode is not ideal, and the goal of the new error reporting machinery is * to avoid ever reporting errors in this style. diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 15c96bb12822..68b7ccab99c7 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - Constructor constructor; + int tag; tvm::Array fields; + /*! \brief Optional field tracking ADT constructor. 
*/ + Constructor constructor; + void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("constructor", &constructor); + v->Visit("tag", &tag); v->Visit("fields", &fields); + v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(Constructor constructor, - tvm::Array fields); + TVM_DLL static ConstructorValue make(int tag, + tvm::Array fields, + Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 6441fb3f5b9c..638f75968fd3 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -87,14 +87,14 @@ class ModuleNode : public RelayNode { * \param update Controls whether you can replace a definition in the * environment. */ - void Add(const GlobalVar& var, const Function& func, bool update = false); + TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. * \param type The type definition. */ - void AddDef(const GlobalTypeVar& var, const TypeData& type); + TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); /*! * \brief Add a function to the global environment. @@ -103,69 +103,69 @@ class ModuleNode : public RelayNode { * * It does not do type inference as Add does. */ - void AddUnchecked(const GlobalVar& var, const Function& func); + TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); /*! * \brief Update a function in the global environment. * \param var The name of the global function to update. * \param func The new function. */ - void Update(const GlobalVar& var, const Function& func); + TVM_DLL void Update(const GlobalVar& var, const Function& func); /*! * \brief Remove a function from the global environment. * \param var The name of the global function to update. 
*/ - void Remove(const GlobalVar& var); + TVM_DLL void Remove(const GlobalVar& var); /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalVar GetGlobalVar(const std::string& str); + TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. * \returns The global variable. */ - GlobalTypeVar GetGlobalTypeVar(const std::string& str); + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ - Function Lookup(const GlobalVar& var); + TVM_DLL Function Lookup(const GlobalVar& var) const; /*! * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - Function Lookup(const std::string& name); + TVM_DLL Function Lookup(const std::string& name) const; /*! * \brief Lookup a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ - TypeData LookupDef(const GlobalTypeVar& var); + TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const; /*! * \brief Lookup a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TypeData LookupDef(const std::string& var); + TVM_DLL TypeData LookupDef(const std::string& var) const; /*! * \brief Update the functions inside this environment by * functions in another environment. * \param other The other environment. */ - void Update(const Module& other); + TVM_DLL void Update(const Module& other); /*! \brief Construct a module from a standalone expression. 
* @@ -177,7 +177,7 @@ class ModuleNode : public RelayNode { * * \returns A module with expr set as the entry point. */ - static Module FromExpr( + TVM_DLL static Module FromExpr( const Expr& expr, const tvm::Map& global_funcs = {}); diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 31067925fa63..fff630f55eb7 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -20,51 +20,18 @@ /*! * \file tvm/relay/pass.h * \brief The set of Relay passes written in C++. - * - * This file also implements a pass manager. The pass manager manages a sequence - * of Relay-to-Relay transformation passes over a particlar unit of AST. The - * design is largely inspired from LLVM's pass manager and modern deep learning - * frameworks that perform tensor->tensor transformations. - * - * The responsibilities of a traditional compiler pass manager usually involves: - * - Organizing the execution order of optimization passes though not - * necessarily in the optimal sequence. - * - Collecting required analysis information and keep them up-to-date. - * - Reducing the effort required to implement new passes for compiler - * developers, etc. - * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work - * different granularity, i.e. module level, function level, and even sequential - * passe that contains a host of passes. - * - * However, we also extend the functionality of the traditional pass manager - * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass - * manager performs the Relay.Module -> Relay.Module transformation. All - * different types of passes, including the sequential-level pass object, are - * essentially pass objects. This design, therefore, effectively provides users - * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. 
For example, with - * the pass manager, external users will be able to have custom passes correctly - * scheduled without having to modify a single handcrafted pass order. - * - * In the future we need to describe constraints between passes. For example, - * we may want to preserve dependencies between different passes and validate - * them on the completion of a certain pass. - * - * We also need to store side information and import the error reporting system. - */ + */ #ifndef TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_ #include #include -#include #include #include #include #include #include +#include #include #include #include @@ -72,174 +39,6 @@ namespace tvm { namespace relay { -namespace pass { - -/* - * \brief The context of pass. - */ -class PassContext; - -/*! - * \brief PassContextNode contains the information that a pass can rely on, such as - * analysis results. - */ -class PassContextNode : public RelayNode { - public: - /*! - * \brief The error reporter used to notify users why an optimization fails. - */ - ErrorReporter err_reporter; - - PassContextNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - } - - TVM_DLL static PassContext make(); - - static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassContext, PassContextNode) - -/* - * \brief The meta data of a pass. - * - * PassInfo can be extended conveniently in the future if more meta information - * is needed. - */ -class PassInfo; - -/*! - * \brief PassInfoNode contains meta data that will be used to help optimization - * and analysis. - */ -class PassInfoNode : public RelayNode { - public: - /*! \brief The minimal optimization level that this pass will be enabled. */ - int opt_level; - - /*! \brief The name of an optimization/analysis pass. */ - std::string name; - - /*! \brief The passes that are required to perform the current pass. 
*/ - tvm::Array required; - - PassInfoNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("opt_level", &opt_level); - v->Visit("name", &name); - v->Visit("required", &required); - } - - TVM_DLL static PassInfo make(int opt_level, std::string name, - tvm::Array required); - - static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); -}; - -TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) - -class Pass; - -/*! - * \brief PassNode is the base type of differnt types of optimization passes. - * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. - */ -class PassNode : public RelayNode { - public: - /* - * \brief Get the pass information/meta data. */ - virtual PassInfo Info() const = 0; - - /*! - * \brief Set the context information for a pass. - * - * \param pass_ctx The context information for a certain pass. - */ - virtual void SetContext(const PassContext& pass_ctx) = 0; - - /*! - * \brief Execute the optimization pass using a functor. - * - * \param mod The module that an optimization pass runs on. - * - * \return The updated module. - */ - virtual Module operator()(const Module& mod) const = 0; - - void VisitAttrs(tvm::AttrVisitor* v) override {} - - static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); -}; - -class Pass : public NodeRef { - public: - Pass() = default; - explicit Pass(NodePtr p) : NodeRef(p) {} - - PassNode* operator->() const { - return static_cast(this->node_.get()); - } - - using ContainerType = PassNode; -}; - -/* - * \brief Create a module pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the module pass. - * \param name The name of the module pass. - * \param required The list of the passes that the module pass is dependent on. - * - * \return The created module pass. 
- */ -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); - -/* - * \brief Create a function pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the function pass. - * \param name The name of the function pass. - * \param required The list of the passes that the function pass is dependent on. - * - * \return The created function pass. - */ -Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); -/* - * \brief Create a sequential pass. - * - * \param passes The optimization passes will be performed. - * \param opt_level The optimization level of the sequential pass. - * \param name The name of the sequential pass. - * \param required The list of the passes that the sequential pass is dependent on. - * \param disabled The disabled passes. - * - * \return The created sequential pass. - */ -Pass CreateSequentialPass(const tvm::Array& passes, - int opt_level, - const std::string& name, - const tvm::Array& required, - const tvm::Array& disabled); - -} // namespace pass - /*! * \brief Infer the type of an expression. * @@ -286,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, */ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); -/*! \brief Compare two expressions for structural equivalence. +/*! + * \brief Compare two expressions for structural equivalence. * * This comparison operator respects scoping and compares * expressions without regard to variable choice. @@ -303,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod); */ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); -/*! \brief Compare two types for structural equivalence. +/*! + * \brief Compare two types for structural equivalence. 
* * This comparison operator respects scoping and compares * expressions without regard to variable choice. @@ -321,7 +122,26 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2); */ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); -/*! \brief Add abstraction over a function +/*! + * \brief Compare two patterns for structural equivalence. + * + * This comparison operator respects scoping and compares + * patterns without regard to variable choice. + * + * For example: `A(x, _, y)` is equal to `A(z, _, a)`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence + * for more details. + * + * \param t1 The left hand pattern. + * \param t2 The right hand pattern. + * + * \return true if equal, otherwise false + */ +TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); + +/*! + * \brief Add abstraction over a function * * For example: `square` is transformed to * `fun x -> square x`. @@ -337,7 +157,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); */ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); -/*! \brief Check that each Var is only bound once. +/*! + * \brief Check that each Var is only bound once. * * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice. * @@ -350,7 +171,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); */ TVM_DLL bool WellFormed(const Expr& expr); -/*! \brief Get all bound variables from expression expr. +/*! + * \brief Get all bound variables from expression expr. * * Bound variables are all variables that are declared in the expr. * They only have meaning inside that expr, and can only be used in it. @@ -361,7 +183,8 @@ TVM_DLL bool WellFormed(const Expr& expr); */ TVM_DLL tvm::Array BoundVars(const Expr& expr); -/*! \brief Get all bound variables from pattern pat. +/*! + * \brief Get all bound variables from pattern pat. * * Bound variables are all variables that got bound by the pat. 
* They only have meaning inside that expr, and can only be used in it. @@ -372,7 +195,8 @@ TVM_DLL tvm::Array BoundVars(const Expr& expr); */ TVM_DLL tvm::Array BoundVars(const Pattern& pat); -/*! \brief Get free type parameters from expression expr. +/*! + * \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a * let or a function parameter in the context. @@ -383,7 +207,8 @@ TVM_DLL tvm::Array BoundVars(const Pattern& pat); */ TVM_DLL tvm::Array FreeVars(const Expr& expr); -/*! \brief Get all variables from expression expr. +/*! + * \brief Get all variables from expression expr. * * \param expr the expression. * @@ -391,7 +216,8 @@ TVM_DLL tvm::Array FreeVars(const Expr& expr); */ TVM_DLL tvm::Array AllVars(const Expr& expr); -/*! \brief Get free TypeVars from expression expr. +/*! + * \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function * type in the context. @@ -403,7 +229,8 @@ TVM_DLL tvm::Array AllVars(const Expr& expr); */ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get free TypeVars from type t. +/*! + * \brief Get free TypeVars from type t. * * Free type parameters are type parameters that are not bound by a function * type in the context. @@ -415,7 +242,8 @@ TVM_DLL tvm::Array FreeTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); -/*! \brief Get all bound type variables from expression expr. +/*! + * \brief Get all bound type variables from expression expr. * * Bound variables are all type variables that are declared in the expr. * They only have meaning inside that expr, and can only be used in it. @@ -427,7 +255,8 @@ TVM_DLL tvm::Array FreeTypeVars(const Type& t, const Module& mod); */ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get all bound type variables from type t. +/*! 
+ * \brief Get all bound type variables from type t. * * Bound variables are all type variables that are declared in the type. * They only have meaning inside that type, and can only be used in it. @@ -439,7 +268,8 @@ TVM_DLL tvm::Array BoundTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); -/*! \brief Get all type variables in expression expr. +/*! + * \brief Get all type variables in expression expr. * * \param expr the expression. * \param mod the module. @@ -448,7 +278,8 @@ TVM_DLL tvm::Array BoundTypeVars(const Type& t, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); -/*! \brief Get all type variables in type t. +/*! + * \brief Get all type variables in type t. * * \param t the type. * \param mod the module. @@ -465,32 +296,39 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); * For example, this pass should turn `let a = 1 in 2` into `2`, * as the value of the expression does not depend on a. * - * As another example, `let a = 1 in a` will be optimized into 1. + * As another example, `let a = 1 in a` will be optimized into 1, + * if the flag is turned on. * * \param e the expression to optimize. + * \param inline_once whether or not to inline binding used one. * * \return the optimized expression. */ -TVM_DLL Expr DeadCodeElimination(const Expr& e); +TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); /*! * \brief Fold constant expressions. + * * \param expr the expression to be optimized. + * * \return The optimized expression. */ TVM_DLL Expr FoldConstant(const Expr& expr); /*! * \brief Fuse operations into expr into seperate functions. + * * \param expr The expression. * \param fuse_opt_level Optimization level. * \param mod the module. + * * \return The optimized expression. */ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); /*! 
* \brief Apply rewrite rules to rewrite the expr in post DFS order. + * * \param expr The expression. * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite * rule function. @@ -500,84 +338,77 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * * \param expr The expression. * \param rewrite_func The rewrite func that will apply to all operators. * \param fcontext Additional callback to provide context argument for each call node. * \param fmulti_ref_trigger Transformation function to be called when * an Expr consumed by multiple callers. + * * \return The rewritten expression. */ TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); /*! * \brief Rewrite the annotated program. + * * \param expr The expression. * \param fallback_device The fallback device which is the default device for * operators without annotation. + * * \return The updated program. */ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); /*! * \brief Collect the device mapping information of each expression. + * * \param expr The expression. + * * \return The device mapping. */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); -/*! \brief A hashing structure in the style of std::hash. */ -struct StructuralHash { - /*! \brief Hash a Relay type. 
- * - * Implements structural hashing of a Relay type. - * - * \param type the type to hash. - * - * \return the hash value. - */ - size_t operator()(const Type& type) const; - - /*! \brief Hash a Relay expression. - * - * Implements structural hashing of a Relay expression. - * - * \param expr the expression to hash. - * - * \return the hash value. - */ - size_t operator()(const Expr& expr) const; -}; +/*! + * \brief Collect the device anntation operators. + * + * \param expr The expression. + * + * \return The annotated expression to device type mapping for annotation ops. + */ +TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); -/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). +/*! + * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). * * It will turn an expression that is in a graph form (with sharing implicit), * to an expression with explicit sharing (A-Normal Form). * * The scope of the root expression is the global scope. - + * * The scope of any non root expression is the least common ancestor of all it's scope. * * Values are ordered by post-DFS order in each scope. * - * \param e the expression to observably share - * + * \param e the expression to observably share. * \param mod The module used for referencing global functions, can be * None. * - * \return expression in A-Normal Form + * \return expression in A-Normal Form. */ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); -/*! \brief Remove let binding and directly share via pointer instead. +/*! + * \brief Remove let binding and directly share via pointer instead. * * It will remove all let binding, * and turn all of the variable bound by let into direct pointer reference. @@ -588,18 +419,72 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); */ TVM_DLL Expr ToGraphNormalForm(const Expr& e); -/*! \brief Aggressive constant propagation/constant folding/inlining. +/*! 
+ * \brief Finds cases that the given match expression does not catch, if any. + * + * \param match the match expression to test + * + * \param mod The module used for accessing global type var definitions, can be None. + * + * \return Returns a list of cases (as patterns) that are not handled by the match + * expression. + */ +TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); + +/*! + * \brief Aggressive constant propagation/constant folding/inlining. * It will do as much computation in compile time as possible. * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). * As a side effect, code size will explode. + * + * \param e the expression + * \param mod the module + * + * \return the optimized expression. + */ +TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); + +/*! + * \brief Bind the free variables to a Relay expression. + * + * \param expr The expression. + * \param bind_map The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. */ -Expr PartialEval(const Expr& e); +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); + +/*! \brief A hashing structure in the style of std::hash. */ +struct StructuralHash { + /*! \brief Hash a Relay type. + * + * Implements structural hashing of a Relay type. + * + * \param type the type to hash. + * + * \return the hash value. + */ + size_t operator()(const Type& type) const; + + /*! \brief Hash a Relay expression. + * + * Implements structural hashing of a Relay expression. + * + * \param expr the expression to hash. + * + * \return the hash value. + */ + size_t operator()(const Expr& expr) const; +}; namespace vm { -/*! \brief Compile a module, and construct the virtual machine. +/*! + * \brief Compile a module, and construct the virtual machine. * * \param mod The module to compile. + * * \return The constructed virtual machine. 
*/ runtime::vm::VirtualMachine CompileModule(const Module& mod); diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h new file mode 100644 index 000000000000..04b4e64dc9c3 --- /dev/null +++ b/include/tvm/relay/transform.h @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/transform.h + * + * This file implements a pass manager. The pass manager manages a sequence + * of Relay-to-Relay transformation passes over a particlar unit of AST. The + * design is largely inspired from LLVM's pass manager and modern deep learning + * frameworks that perform tensor->tensor transformations. + * + * The responsibilities of a traditional compiler pass manager usually involves: + * - Organizing the execution order of optimization passes though not + * necessarily in the optimal sequence. + * - Collecting required analysis information and keep them up-to-date. + * - Reducing the effort required to implement new passes for compiler + * developers, etc. + * + * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * different granularity, i.e. module level, function level, and even sequential + * passe that contains a host of passes. 
+ * + * However, we also extend the functionality of the traditional pass manager + * with the consideration of requirements/convention from deep learning + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * manager performs the Relay.Module -> Relay.Module transformation. All + * different types of passes, including the sequential-level pass object, are + * essentially pass objects. This design, therefore, effectively provides users + * a consistent and convenient interface, i.e. Pass, to play with. It offers a + * means to ease the development and testing of Relay passes. For example, with + * the pass manager, external users will be able to have custom passes correctly + * scheduled without having to modify a single handcrafted pass order. + * + * In the future we need to describe constraints between passes. For example, + * we may want to preserve dependencies between different passes and validate + * them on the completion of a certain pass. + * + * We also need to store side information and import the error reporting system. + */ +#ifndef TVM_RELAY_TRANSFORM_H_ +#define TVM_RELAY_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +/* + * \brief The context of pass. + */ +class PassContext; + +/*! + * \brief PassContextNode contains the information that a pass can rely on, + * such as analysis results. + */ +class PassContextNode : public RelayNode { + public: + /*! + * \brief The error reporter used to notify users why an optimization fails. + */ + ErrorReporter err_reporter; + + /*! \brief The default optimization level. */ + int opt_level{2}; + + /*! \brief CPU is the default fallback device for heterogeneous execution. */ + int fallback_device{static_cast(kDLCPU)}; + + /*! \brief The list of required passes. */ + tvm::Array required_pass; + /*! \brief The list of disabled passes. 
*/ + tvm::Array disabled_pass; + + PassContextNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("fallback_device", &fallback_device); + v->Visit("required_pass", &required_pass); + v->Visit("disabled_pass", &disabled_pass); + } + + static constexpr const char* _type_key = "relay.PassContext"; + TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); +}; + +/*! + * \brief PassContext that is used to configure the pass behavior. + * + * \code + * + * auto new_ctx = PassContext::Create(); + * ctx->opt_level = 2; + * ctx->fallback_device = kDLCPU; + * With scope(ctx); + * // pass context in effect. + * + * \endcode + */ +class PassContext : public NodeRef { + public: + PassContext() {} + explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {} + /*! + * \brief const accessor. + * \return const access pointer. + */ + const PassContextNode* operator->() const { + CHECK(node_.get() != nullptr); + return static_cast(node_.get()); + } + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + PassContextNode* operator->() { + CHECK(node_.get() != nullptr); + return static_cast(node_.get()); + } + /*! + * \brief Construct a PassContext containing the default configurations. + * \return The new PassContext. + */ + TVM_DLL static PassContext Create(); + /*! + * \brief Get the default pass context in the current scope. + * \return The pass context. + */ + TVM_DLL static PassContext Current(); + + // accessor. + using ContainerType = PassContextNode; + class Internal; + + private: + // The entry of a pass context scope. + TVM_DLL void EnterWithScope(); + // The exit of a pass context scope. + TVM_DLL void ExitWithScope(); + + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class tvm::With; +}; + +/* + * \brief The meta data of a pass. + * + * PassInfo can be extended conveniently in the future if more meta information + * is needed. 
+ */ +class PassInfo; + +/*! + * \brief PassInfoNode contains meta data that will be used to help optimization + * and analysis. + */ +class PassInfoNode : public RelayNode { + public: + /*! \brief The minimal optimization level that this pass will be enabled. */ + int opt_level; + + /*! \brief The name of an optimization/analysis pass. */ + std::string name; + + /*! \brief The passes that are required to perform the current pass. */ + tvm::Array required; + + PassInfoNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("opt_level", &opt_level); + v->Visit("name", &name); + v->Visit("required", &required); + } + + TVM_DLL static PassInfo make(int opt_level, + std::string name, + tvm::Array required); + + static constexpr const char* _type_key = "relay.PassInfo"; + TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode); +}; + +TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode) + +class Pass; + +/*! + * \brief PassNode is the base type of differnt types of optimization passes. + * It is designed as a pure class and implemented by different pass subclasses + * at different granularity of Relay nodes. + */ +class PassNode : public RelayNode { + public: + /*! + * \brief Get the pass information/meta data. */ + virtual PassInfo Info() const = 0; + + /*! + * \brief Transform mod using the default PassContext in the current scope. + * + * \param mod The module that an optimization pass runs on. + * + * \return The transformed module. + */ + Module operator()(const Module& mod) const { + return this->operator()(mod, PassContext::Current()); + } + + /*! + * \brief Transform mod using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. 
+ */ + virtual Module operator()(const Module& mod, + const PassContext& pass_ctx) const = 0; + + void VisitAttrs(tvm::AttrVisitor* v) override {} + + static constexpr const char* _type_key = "relay.Pass"; + TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode); +}; + +class Pass : public NodeRef { + public: + /*! + * \brief Transform mod using the default PassContext in the current scope. + * + * \param mod The module that an optimization pass runs on. + * + * \return The transformed module. + */ + Module operator()(const Module& mod) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod); + } + /*! + * \brief Transform mod using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ + Module operator()(const Module& mod, + const PassContext& pass_ctx) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod, pass_ctx); + } + + TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode); +}; + +class SequentialNode; + +class Sequential : public Pass { + public: + /*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param pass_info The pass metadata. + */ + TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); + + /*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param name The name of a sequential pass. It's defaulted to "sequential". + * This allows users to only provide a list of passes and execute them + * under a given context. 
+ */ + TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + + Sequential() = default; + explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {} + + const SequentialNode* operator->() const; + using ContainerType = Sequential; +}; + +/* + * \brief Create a module pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the module pass. + * \param name The name of the module pass. + * \param required The list of the passes that the module pass is dependent on. + * + * \return The created module pass. + */ +Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/* + * \brief Create a function pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the function pass. + * \param name The name of the function pass. + * \param required The list of the passes that the function pass is dependent on. + * + * \return The created function pass. + */ +TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< + Function(Function, Module, PassContext)>& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +/*! \brief Remove expressions which does not effect the program result. + * + * It will remove let bindings which are not referenced, + * and inline let bindings that are only used once. + * + * For example, this pass should turn `let a = 1 in 2` into `2`, + * as the value of the expression does not depend on a. + * + * As another example, `let a = 1 in a` will be optimized into 1. + * + * \param inline_once whether or not to inline binding used one. + * + * \return the pass. + */ +TVM_DLL Pass DeadCodeElimination(bool inline_once = false); + +/*! + * \brief Fold constant expressions. + * + * \return The pass. + */ +TVM_DLL Pass FoldConstant(); + +/*! 
+ * \brief Fuse operations into expr into seperate functions. + * + * \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context. + * + * \return The pass. + */ +TVM_DLL Pass FuseOps(int fuse_opt_level = -1); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * + * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite + * rule function. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The pass. + */ +TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function + fmulti_ref_trigger = nullptr); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * + * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The pass. + */ +TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +/*! + * \brief Rewrite the annotated program. + * + * \param fallback_device The fallback device which is the default device for + * operators without annotation. + * + * \return The pass. + */ +TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); + +/*! + * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). + * + * It will turn an expression that is in a graph form (with sharing implicit), + * to an expression with explicit sharing (A-Normal Form). + * + * The scope of the root expression is the global scope. 
+ * + * The scope of any non root expression is the least common ancestor of all it's scope. + * + * Values are ordered by post-DFS order in each scope. + * + * \return The pass. + */ +TVM_DLL Pass ToANormalForm(); + +/*! + * \brief Remove let binding and directly share via pointer instead. + * + * It will remove all let binding, + * and turn all of the variable bound by let into direct pointer reference. + * + * \return the expression in graph normal form. + */ +TVM_DLL Pass ToGraphNormalForm(); + +/*! + * \brief Aggressive constant propagation/constant folding/inlining. + * + * It will do as much computation in compile time as possible. + * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). + * As a side effect, code size will explode. + * + * \return the optimized expression. + */ +TVM_DLL Pass PartialEval(); + +/*! + * \brief Simplify certain operators during inference. For example, batch norm + * will be unpacked into a number of simplified operators. + * + * \return The Pass. + */ +TVM_DLL Pass SimplifyInference(); + +/*! + * \brief Infer the type of an expression. + * + * The result of type checking is a new expression with unambigous + * type information filled in, as well as it's checked type field + * populated with the result type. + * + * \return The pass. + */ +TVM_DLL Pass InferType(); + +/*! + * \brief Search and eliminate common subexpression. For example, if there are + * two expressions evaluated to an identical value, a single variable is created + * and these two expressions are replaced by this variable. + * + * \param fskip The callback argument that allows to skip certain expressions. + * + * \return The pass. + */ +TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); + +/*! + * \brief Combine parallel 2d convolutions into a single convolution if the + * number of branches of this conv2d operator is not less than + * `min_num_branch`. 
+ * + * \param min_num_branches The minimun number of branches. + * + * \return The pass. + */ +TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); + +/*! + * \brief Backward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass BackwardFoldScaleAxis(); + +/*! + * \brief Forward fold axis scaling into weights of conv/dense operators. + * + * \return The pass. + */ +TVM_DLL Pass ForwardFoldScaleAxis(); + +/*! + * \brief A sequential pass that executes ForwardFoldScaleAxis and + * BackwardFoldScaleAxis passes. + * + * \return The pass. + */ +TVM_DLL Pass FoldScaleAxis(); + +/*! + * \brief Canonicalize some operators to the simplified operators. For example, + * bias_add can be canonicalized to expand_dims and broadcast_add. + * + * \return The pass. + */ +TVM_DLL Pass CanonicalizeOps(); + +/*! + * \brief Alternate the layouts of operators or replace primitive operators + * with other expressions. + * + * \return The pass. + */ +TVM_DLL Pass AlterOpLayout(); + +/*! + * \brief Canonicalize cast expressions to make operator fusion more efficient. + * + * \return The pass. + */ +TVM_DLL Pass CanonicalizeCast(); + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORM_H_ diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f992e87ad100..ba2c0d2291b6 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -114,6 +114,8 @@ typedef enum { // The following section of code is used for non-reserved types. kExtReserveEnd = 64U, kExtEnd = 128U, + // The rest of the space is used for custom, user-supplied datatypes + kCustomBegin = 129U, } TVMTypeCode; /*! @@ -185,7 +187,7 @@ TVM_DLL void TVMAPISetLastError(const char* msg); /*! 
* \brief return str message of the last error * all function in this file will return 0 when success - * and -1 when an error occured, + * and -1 when an error occurred, * TVMGetLastError can be called to retrieve the error * * this function is threadsafe and can be called by different thread diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 9fcefcbbe4b1..82b3dd469541 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -60,6 +60,29 @@ namespace tvm { class Integer; namespace runtime { + +/*! + * \brief Runtime utility for getting custom type name from code + * \param type_code Custom type code + * \return Custom type name + */ +TVM_DLL std::string GetCustomTypeName(uint8_t type_code); + +/*! + * \brief Runtime utility for checking whether custom type is registered + * \param type_code Custom type code + * \return Bool representing whether type is registered + */ +TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); + +/*! 
+ * \brief Runtime utility for parsing string of the form "custom[]" + * \param s String to parse + * \param scan pointer to parsing pointer, which is scanning across s + * \return type code of custom type parsed + */ +TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); + // forward declarations class TVMArgs; class TVMArgValue; @@ -939,7 +962,11 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } - os << TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + os << "custom[" << GetCustomTypeName(t.code) << "]"; + } else { + os << TypeCode2Str(t.code); + } if (t.code == kHandle) return os; os << static_cast(t.bits); if (t.lanes != 1) { @@ -960,7 +987,11 @@ inline std::string TVMType2String(TVMType t) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { return "bool"; } - repr += TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + repr += "custom[" + GetCustomTypeName(t.code) + "]"; + } else { + repr += TypeCode2Str(t.code); + } if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); if (t.lanes != 1) { @@ -994,6 +1025,8 @@ inline TVMType String2TVMType(std::string s) { t.bits = 1; t.lanes = 1; return t; + } else if (s.substr(0, 6) == "custom") { + t.code = ParseCustomDatatype(s, &scan); } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 8911ad499e4c..028a5ff9d1ad 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -56,13 +56,14 @@ enum class Opcode { InvokeClosure = 3U, InvokePacked = 4U, AllocTensor = 5U, - AllocDatatype = 6U, - AllocClosure = 7U, - GetField = 8U, - If = 9U, - Select = 10U, - LoadConst = 11U, - Goto = 12U + AllocTensorReg = 6U, + AllocDatatype = 7U, + AllocClosure = 8U, + GetField = 9U, + If = 10U, + Select = 11U, + LoadConst = 12U, + Goto = 13U }; /*! 
\brief A single virtual machine instruction. @@ -83,11 +84,19 @@ struct Instruction { union { struct /* AllocTensor Operands */ { + /*! \brief The number of dimensions. */ + uint32_t ndim; + /*! \brief The shape of tensor. */ + int64_t* shape; + /*! \brief The datatype of tensor to be allocated. */ + DLDataType dtype; + } alloc_tensor; + struct /* AllocTensorReg Operands */ { /*! \brief The register to read the shape out of. */ RegName shape_register; /*! \brief The datatype of tensor to be allocated. */ DLDataType dtype; - }; + } alloc_tensor_reg; struct /* InvokeClosure Operands */ { /*! \brief The register containing the closure. */ RegName closure; @@ -192,13 +201,20 @@ struct Instruction { */ static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args); - /*! \brief Construct an allocate tensor instruction. + /*! \brief Construct an allocate tensor instruction with constant shape. + * \param shape The shape of the tensor. + * \param dtype The dtype of the tensor. + * \param dst The destination register. + * \return The allocate tensor instruction. + */ + static Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst); + /*! \brief Construct an allocate tensor instruction with register. * \param shape_register The register containing the shape. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst); /*! \brief Construct an allocate datatype instruction. * \param tag The datatype tag. * \param num_fields The number of fields for the datatype. 
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 774d7cd9a40a..659b42aa1afa 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -102,8 +102,8 @@ class Stage : public NodeRef { */ EXPORT Stage& bind(IterVar ivar, IterVar thread_ivar); /*! - * \brief Set predicate under which store to the array can be performed. - * Use this when there are duplicated threads doing the same store and we only + * \brief Set the predicate to determine whether a store to the array should be performed. + * Use this when there are multiple threads performing the same store and we only * need one of them to do the store. * * \note This is a dangerous scheduling primitive that can change behavior of program. diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h index e61ce6634bd3..b5ca6eb4358b 100644 --- a/include/tvm/tensor_intrin.h +++ b/include/tvm/tensor_intrin.h @@ -67,6 +67,11 @@ class TensorIntrinNode : public Node { * When it is a constant, it means we can only take data in that shape. */ Array buffers; + /*! \brief List of scalar variables, used in body. These placeholders + * will be bound to expressions passed in when the TensorIntrin is called + * from a TensorComputeOp. + */ + Array scalar_params; /*! \brief The normal statement to execute the intrinsic */ Stmt body; /*! @@ -87,6 +92,7 @@ class TensorIntrinNode : public Node { v->Visit("op", &op); v->Visit("inputs", &inputs); v->Visit("buffers", &buffers); + v->Visit("scalar_params", &scalar_params); v->Visit("body", &body); v->Visit("reduce_init", &reduce_init); v->Visit("reduce_update", &reduce_update); @@ -96,6 +102,7 @@ class TensorIntrinNode : public Node { Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); @@ -134,22 +141,29 @@ class TensorIntrinCallNode : public Node { Array tensors; /*! \brief regions of input tensors */ Array regions; + + /*! 
* \brief IterVar on each reduction axis, if the * intrin will use the reduce axis */ Array reduce_axis; + /*! \brief scalar expression inputs */ + Array scalar_inputs; + void VisitAttrs(AttrVisitor* v) final { v->Visit("intrin", &intrin); v->Visit("tensors", &tensors); v->Visit("regions", ®ions); v->Visit("reduce_axis", &reduce_axis); + v->Visit("scalar_inputs", &scalar_inputs); } static TensorIntrinCall make(TensorIntrin intrin, Array tensors, Array regions, - Array reduce_axis); + Array reduce_axis, + Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node); diff --git a/jvm/README.md b/jvm/README.md index 626a87c84cf8..c5996313447b 100644 --- a/jvm/README.md +++ b/jvm/README.md @@ -30,6 +30,7 @@ This folder contains the Java interface for TVM runtime. It brings TVM runtime t - JDK 1.6+. Oracle JDK and OpenJDK are well tested. - Maven 3 for build. +- LLVM (TVM4J need LLVM support. Please refer to [build-the-shared-library](https://docs.tvm.ai/install/from_source.html#build-the-shared-library) for how to enable LLVM support.) ### Modules diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index 75054e892d8e..773bc63b7dad 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -60,7 +60,7 @@ NNVM_DLL void NNAPISetLastError(const char* msg); /*! * \brief return str message of the last error * all function in this file will return 0 when success - * and -1 when an error occured, + * and -1 when an error occurred, * NNGetLastError can be called to retrieve the error * * this function is threadsafe and can be called by different thread diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 976ad929f496..ad328c30312a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -136,6 +136,17 @@ using FInferType = FInferNodeEntryAttr; */ using TIsBackward = bool; +/*! 
+ * \brief Whether this op is a ghost node. + * If TIsGhost is true: + * - The node with this op will not be visible in the indexed graph. + * + * \note Register under "TIsGhost" + * This enables shape/type inference for backward nodes when + * fusion is present. + */ +using TIsGhost = bool; + /*! * \brief Get possible inplace options. * This function enables optimization to reuse memory of inputs in output. diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 0c1bc5047870..a49e9741a901 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -171,7 +171,7 @@ def _update_shape_dtype(shape, dtype, params): shape.update({k : v.shape for k, v in params.items()}) if isinstance(dtype, str): for k, v in params.items(): - if v.dtype != dtype: + if v.dtype != dtype and v.shape: raise ValueError( "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype)) else: diff --git a/nnvm/python/nnvm/frontend/caffe2.py b/nnvm/python/nnvm/frontend/caffe2.py index 2b3ff5a27e01..f951db66b5a6 100644 --- a/nnvm/python/nnvm/frontend/caffe2.py +++ b/nnvm/python/nnvm/frontend/caffe2.py @@ -411,7 +411,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from Caffe2 operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters diff --git a/nnvm/python/nnvm/frontend/common.py b/nnvm/python/nnvm/frontend/common.py index 610546d1973b..0e09a2c43323 100644 --- a/nnvm/python/nnvm/frontend/common.py +++ b/nnvm/python/nnvm/frontend/common.py @@ -58,7 +58,7 @@ def __call__(self, inputs, attrs, *args): class AttrConverter(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. 
An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -72,12 +72,12 @@ class AttrConverter(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in nnvm. Log warnings. ignores : list diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py index 7af8cf8833dd..f647a644bd2b 100644 --- a/nnvm/python/nnvm/frontend/keras.py +++ b/nnvm/python/nnvm/frontend/keras.py @@ -180,7 +180,6 @@ def _convert_convolution(insym, keras_layer, symtab): else: kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape weight = weightList[0].transpose([3, 2, 0, 1]) - dilation = [1, 1] if isinstance(keras_layer.dilation_rate, (list, tuple)): dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] else: diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 77671225aa3e..6f6bfc87ea8a 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -269,7 +269,7 @@ def _crop_like(inputs, attrs): raise tvm.error.OpAttributeUnimplemented( 'Center crop is not supported in operator crop_like.') if len(inputs) < 2: - raise RuntimeError("Only support crop_like pattern.") + raise tvm.error.OpAttributeUnimplemented("Only support crop_like pattern.") new_attrs["axis"] = [2, 3] return 
get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs) diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py index eb78b7845c23..b5e294b97fb1 100644 --- a/nnvm/python/nnvm/frontend/onnx.py +++ b/nnvm/python/nnvm/frontend/onnx.py @@ -27,6 +27,13 @@ __all__ = ['from_onnx'] +def onnx_storage_order2layout(storage_order): + if storage_order not in (0, 1): + raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1') + + return 'NCHW' if storage_order == 0 else 'NHWC' + + class OnnxOpConverter(object): """ A helper class for holding onnx op converters. """ @@ -207,8 +214,38 @@ def _impl_v1(cls, inputs, attr, params): class MaxPool(Pool): + """ Operator converter for MaxPool + """ name = 'max_pool' + @classmethod + def _impl_v8(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + # TODO(higumachan): make sure ceil_mode in onnx, and layout?
+ extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, attr, params) + + @classmethod + def _impl_v10(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + 'ceil_mode': 'ceil_mode' + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + custom_check=dimension_constraint())(inputs, attr, params) class Mul(Elemwise): name = 'mul' @@ -830,6 +867,19 @@ def from_onnx(self, graph, opset): else: self._num_input += 1 self._nodes[i_name] = _sym.Variable(name=i_name) + # get list of unsupported ops + convert_map = _get_convert_map(opset) + unsupported_ops = set() + for node in graph.node: + op_name = node.op_type + if op_name not in convert_map and \ + op_name != 'Constant' and \ + op_name not in _identity_list: + unsupported_ops.add(op_name) + if unsupported_ops: + msg = 'The following operators are not supported for frontend ONNX: ' + msg += ', '.join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: op_name = node.op_type @@ -913,7 +963,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from onnx operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. 
Parameters diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 2f91cad8143a..7b4147155d93 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -1188,7 +1188,7 @@ def __init__(self): self._input_shapes = {} def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): - """Construct nnvm nodes from tensorflow graph definition - GraphDef. + """Construct nnvm nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to NNVM. Some of the assumptions listed below. @@ -1197,7 +1197,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): -> All Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Graph should be frozen with add_shapes=True. - Or user can pass input shape dictionaly optionally. + Or user can pass input shape dictionary optionally. -> DecodeJpeg, ResizeBilinear: These are dummy operators. Hence user should handle preprocessing outside. -> CheckNumerics: No implementation as of now for this. @@ -1214,6 +1214,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): shape : Dictionary of input dimensions (Optional) Graph level input shape dictionary. + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : nnvm.sym.Symbol @@ -1547,7 +1550,7 @@ def _convert_rnn_operator(self, op_name, inputs, def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to nnvm operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. 
Parameters @@ -1599,7 +1602,7 @@ def _fix_extranodes(self, op_name, attr, inputs): return inputs def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): - """ Load tensorflow graph which is a python tensorflow graph object into nnvm graph. + """Load tensorflow graph which is a python tensorflow graph object into nnvm graph. The companion parameters will be handled automatically. Parameters @@ -1607,6 +1610,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): graph : GraphDef object Tensorflow GraphDef + layout : target layout to be used (Optional) + NCHW only supported now to enable NHWC models on GPU. + + shape : Dictionary of input dimensions (Optional) + Graph level input shape dictionary. + + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- sym : nnvm.Symbol diff --git a/nnvm/python/nnvm/testing/__init__.py b/nnvm/python/nnvm/testing/__init__.py index 44b8529821d0..41bcf83eb511 100644 --- a/nnvm/python/nnvm/testing/__init__.py +++ b/nnvm/python/nnvm/testing/__init__.py @@ -13,5 +13,4 @@ from . import inception_v3 from . import dcgan from . import dqn -from . import yolo_detection from . 
import check_computation diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 92ff98618ec8..e2d6d36020f1 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -76,6 +76,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] (const NodePtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; CHECK_LT(nodes_.size(), std::numeric_limits::max()); uint32_t nid = static_cast(nodes_.size()); CHECK(n); @@ -103,8 +105,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { inputs_rptr.push_back(input_entries_.size()); // control deps for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; auto it = node2index_.find(nptr.get()); - CHECK(it != node2index_.end() && it->first == nptr.get()); + CHECK(it != node2index_.end()) << "control dep not found in graph"; control_deps_.push_back(it->second); } control_rptr.push_back(control_deps_.size()); diff --git a/nnvm/tests/python/compiler/test_top_level2.py b/nnvm/tests/python/compiler/test_top_level2.py index b25feb74793f..3c5651578b64 100644 --- a/nnvm/tests/python/compiler/test_top_level2.py +++ b/nnvm/tests/python/compiler/test_top_level2.py @@ -305,7 +305,7 @@ def test_upsampling_nearest_neighbor(): data = tvm.nd.array(a_np) m.run(x=data) out = m.get_output(0, tvm.nd.empty(oshape, dtype)) - b_np = topi.testing.upsampling_python(a_np, scale, "NCHW") + b_np = topi.testing.upsampling_python(a_np, (scale, scale), "NCHW") tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5) def test_upsampling_bilinear(): diff --git a/nnvm/tests/python/frontend/coreml/test_forward.py b/nnvm/tests/python/frontend/coreml/test_forward.py index 679afe4e86bc..7a9f294f4359 100644 --- a/nnvm/tests/python/frontend/coreml/test_forward.py +++ b/nnvm/tests/python/frontend/coreml/test_forward.py @@ -195,7 +195,7 @@ def 
verify_UpsampleLayerParams(input_dim, scale, mode): a_np = np.full(input_dim, 1, dtype=dtype) if mode == 'NN': - b_np = topi.testing.upsampling_python(a_np, scale) + b_np = topi.testing.upsampling_python(a_np, (scale, scale)) else: new_h = input_dim[2] * scale new_w = input_dim[3] * scale diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py index 7f45a6149efc..4e62ff2e1f33 100644 --- a/nnvm/tests/python/frontend/darknet/test_forward.py +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -27,8 +27,8 @@ from tvm.contrib.download import download_testdata download_testdata.__test__ = False from nnvm import frontend -from nnvm.testing.darknet import LAYERTYPE -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import LAYERTYPE +from tvm.relay.testing.darknet import __darknetffi__ import nnvm.compiler DARKNET_LIB = 'libdarknet2.0.so' diff --git a/nnvm/tests/python/frontend/mxnet/test_forward.py b/nnvm/tests/python/frontend/mxnet/test_forward.py index db5534daee1a..446ebebbfc5a 100644 --- a/nnvm/tests/python/frontend/mxnet/test_forward.py +++ b/nnvm/tests/python/frontend/mxnet/test_forward.py @@ -137,7 +137,7 @@ def test_forward_fc_flatten(): def test_forward_clip(): data = mx.sym.var('data') - data = mx.sym.concat(data, -data, dim=1) # negative part explicity + data = mx.sym.concat(data, -data, dim=1) # negative part explicitly mx_sym = mx.sym.clip(data, a_min=0, a_max=1) verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100)) diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py index 941a275a8045..3365b0f25fb1 100644 --- a/nnvm/tests/python/frontend/onnx/test_forward.py +++ b/nnvm/tests/python/frontend/onnx/test_forward.py @@ -405,7 +405,7 @@ def _test_upsample_nearest(): y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0]) in_array = 
np.random.uniform(size=in_shape).astype(np.float32) - out_array = topi.testing.upsampling_python(in_array, scale, "NCHW") + out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW") graph = helper.make_graph([y], 'upsample_nearest_test', diff --git a/nnvm/tutorials/from_darknet.py b/nnvm/tutorials/from_darknet.py index 857ef46015cd..d2ab647da1b3 100644 --- a/nnvm/tutorials/from_darknet.py +++ b/nnvm/tutorials/from_darknet.py @@ -33,8 +33,8 @@ import nnvm import nnvm.frontend.darknet -import nnvm.testing.yolo_detection -import nnvm.testing.darknet +import tvm.relay.testing.yolo_detection +import tvm.relay.testing.darknet import matplotlib.pyplot as plt import numpy as np import tvm @@ -42,7 +42,7 @@ from ctypes import * from tvm.contrib.download import download_testdata -from nnvm.testing.darknet import __darknetffi__ +from tvm.relay.testing.darknet import __darknetffi__ # Model name MODEL_NAME = 'yolov3' @@ -104,7 +104,7 @@ test_image + '?raw=true' img_path = download_testdata(img_url, test_image, "data") -data = nnvm.testing.darknet.load_image(img_path, netw, neth) +data = tvm.relay.testing.darknet.load_image(img_path, netw, neth) ###################################################################### # Execute on TVM Runtime # ---------------------- @@ -153,12 +153,12 @@ # do the detection and bring up the bounding boxes thresh = 0.5 nms_thresh = 0.45 -img = nnvm.testing.darknet.load_image_color(img_path) +img = tvm.relay.testing.darknet.load_image_color(img_path) _, im_h, im_w = img.shape -dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, +dets = tvm.relay.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh, 1, tvm_out) last_layer = net.layers[net.n - 1] -nnvm.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) +tvm.relay.testing.yolo_detection.do_nms_sort(dets, last_layer.classes, nms_thresh) coco_name = 'coco.names' coco_url = 
'https://github.com/siju-samuel/darknet/blob/master/data/' + coco_name + '?raw=true' @@ -172,6 +172,6 @@ names = [x.strip() for x in content] -nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) +tvm.relay.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes) plt.imshow(img.transpose(1, 2, 0)) plt.show() diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ce6f0602a572..5765eed0ad8b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -38,12 +38,13 @@ from . import hybrid from . import testing from . import error +from . import datatype from . import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, opengl, ext_dev -from ._ffi.runtime_ctypes import TypeCode +from ._ffi.runtime_ctypes import TypeCode, TVMType from ._ffi.ndarray import TVMContext from ._ffi.function import Function from ._ffi.base import TVMError, __version__ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 4ede33a63936..72cff1a10ead 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -91,6 +91,13 @@ def __init__(self, type_str): self.type_code = 4 bits = 64 head = "" + elif head.startswith("custom"): + low, high = head.find('['), head.find(']') + if not low or not high or low >= high: + raise ValueError("Badly formatted custom type string %s" % type_str) + type_name = head[low + 1:high] + self.type_code = _api_internal._datatype_get_type_code(type_name) + head = head[high+1:] else: raise ValueError("Do not know how to handle type %s" % type_str) bits = int(head) if head else bits @@ -100,7 +107,12 @@ def __init__(self, type_str): def __repr__(self): if self.bits == 1 and self.lanes == 1: return "bool" - x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) + if self.type_code in TVMType.CODE2STR: + type_name = 
TVMType.CODE2STR[self.type_code] + else: + type_name = "custom[%s]" % \ + _api_internal._datatype_get_type_name(self.type_code) + x = "%s%d" % (type_name, self.bits) if self.lanes != 1: x += "x%d" % self.lanes return x diff --git a/python/tvm/api.py b/python/tvm/api.py index 66fa4fa30e90..d88f06170543 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -319,7 +319,8 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None): out_ndim, body.intrin, body.tensors, - body.regions) + body.regions, + body.scalar_inputs) else: if not isinstance(body, (list, tuple)): body = [body] diff --git a/python/tvm/arith.py b/python/tvm/arith.py index eda5cb825326..4c3c05f75796 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -32,21 +32,21 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_node +@register_node("arith.IntervalSet") class IntervalSet(IntSet): - """Represent set of continuous interval""" - def min(self): - """get the minimum value""" - return _api_internal._IntervalSetGetMin(self) - - def max(self): - """get the maximum value""" - return _api_internal._IntervalSetGetMax(self) + """Represent set of continuous interval [min_value, max_value] + Parameters + ---------- + min_value : Expr + The minimum value in the interval. -@register_node -class StrideSet(IntSet): - """Represent set of strided integers""" + max_value : Expr + The maximum value in the interval. 
+ """ + def __init__(self, min_value, max_value): + self.__init_handle_by_constructor__( + _make_IntervalSet, min_value, max_value) @register_node("arith.ModularSet") @@ -114,6 +114,7 @@ def __init__(self): self._modular_set = _mod("modular_set") self._rewrite_simplify = _mod("rewrite_simplify") self._canonical_simplify = _mod("canonical_simplify") + self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") def const_int_bound(self, expr): @@ -176,6 +177,24 @@ def canonical_simplify(self, expr): """ return self._canonical_simplify(expr) + def int_set(self, expr, dom_map): + """Compute a symbolic IntSet that covers expr for all values in dom_map. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + dom_map : Dict[Var, tvm.arith.IntSet] + The domain for variables to be relaxed. + + Returns + ------- + result : IntSet + The result. + """ + return self._int_set(expr, dom_map) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/python/tvm/autotvm/graph_tuner/__init__.py b/python/tvm/autotvm/graph_tuner/__init__.py new file mode 100644 index 000000000000..d590db0e7c48 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Autotvm graph tuner API.""" +from __future__ import absolute_import as _abs + +from . import _base +from . import base_graph_tuner + +from .base_graph_tuner import BaseGraphTuner +from .dynamic_programming_tuner import DPTuner +from .pbqp_tuner import PBQPTuner diff --git a/python/tvm/autotvm/graph_tuner/_base.py b/python/tvm/autotvm/graph_tuner/_base.py new file mode 100644 index 000000000000..83b9e06ba564 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/_base.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Helper functions and global data""" + + +RULE_OUT_NODE_NAMES = ["Tuple", "TupleGetItem", "batch_flatten", "transpose", "reshape", + "multibox_prior", "multibox_transform_loc", "where", + "non_max_suppression", "strided_slice"] + +# We set a large time to represent an invalid layout-transformation. +# This number is set to be 10e9 seconds to align with autotvm. 
+INVALID_LAYOUT_TIME = 10e9 diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py new file mode 100644 index 000000000000..0fbfc27310cb --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -0,0 +1,522 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-instance-attributes,too-many-branches,too-many-nested-blocks,invalid-name,unused-argument,unused-variable,no-member,no-value-for-parameter +"""Base class for graph tuner.""" +import logging +from abc import abstractmethod + +import numpy as np +import topi + +import tvm +from tvm import autotvm, relay +from tvm.autotvm.task import get_config +from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args +from tvm.autotvm.record import encode, load_from_file +from tvm.autotvm.measure import MeasureResult, MeasureInput + +from ... import target as _target +from .utils import is_input_node, get_in_nodes, get_out_nodes, has_multiple_inputs, \ + bind_inputs, expr2graph +from ._base import INVALID_LAYOUT_TIME + + +# Setup topi_op_name -> layout function +# NOTE: To add more ops, change the following dictionary. 
OP2LAYOUT = {
    "topi_nn_conv2d": topi.nn.conv2d_infer_layout,
    "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
}


@autotvm.template
def layout_transform(*args):
    """Autotvm layout transform template.

    Deserializes the task arguments, builds a ``topi.layout_transform`` op
    and returns an injective schedule plus its input/output tensors.
    """
    args = deserialize_args(args)
    cfg = get_config()
    # Layout transform does no arithmetic work; -1 is used as the flop
    # placeholder — presumably the autotvm convention for "no flops"; confirm.
    cfg.add_flop(-1)
    data = args[0]
    out = topi.layout_transform(*args)
    sch = topi.generic.schedule_injective([out])
    return sch, [data, out]


class BaseGraphTuner(object):
    """Class to search schedules considering both kernel execution time and
    layout transformation time.

    Before creating a Graph Executor instance, schedule candidates for all kernels in
    graph should be provided through tensor-level tuning.
    """
    def __init__(self, graph, input_shapes, records, target_ops,
                 target, max_sch_num=20, dtype="float32", verbose=True,
                 log_file="graph_tuner.log", log_level=logging.DEBUG,
                 name="graph_tuner"):
        """Create a GlobalTuner instance. Local schedule searching for all nodes with
        target_op in the input graph and layout transformation benchmark need to be
        executed before initialization.

        Parameters
        ----------
        graph : tvm.relay.Expr.Function
            Input graph

        input_shapes : dict of str to tuple.
            Input shapes of graph

        records : str or iterator of (MeasureInput, MeasureResult)
            Collection of kernel level tuning records.
            If it is str, then it should be the filename of a records log file.
            Each row of this file is an encoded record pair.
            Otherwise, it is an iterator.

        target_ops : List of str
            Target tuning operators.

        target : str or tvm.target
            Compilation target.

        max_sch_num : int, optional
            Maximum number of schedule candidates for each workload.

        dtype : str, optional
            Data type.

        verbose : bool, optional
            Whether to also log to the console.

        log_file : str, optional
            graph tuner log file name

        log_level : int, optional
            Logging level passed to the ``logging`` module.

        name : str, optional
            Name of global tuner.
        """
        self._node_list = []
        self._layout_transform_perf_records = {}
        self._layout_transform_interlayer_cost = {}
        self._input_shapes = input_shapes
        # target_ops arrive as callables; only their names are stored.
        self._target_ops = [op.__name__ for op in target_ops]

        self._name = name
        self._max_sch_num = max_sch_num
        self._optimal_sch_dict = {}
        self._records = records
        self._dtype = dtype
        if isinstance(target, str):
            target = _target.create(target)
        self._target = target
        self._optimal_record_dict = {}

        # Set up logger. Existing handlers on a same-named logger are reused
        # so repeated construction does not duplicate log output.
        self._verbose = verbose
        self._logger = logging.getLogger(name + "_logger")
        need_file_handler = need_console_handler = True
        for handler in self._logger.handlers:
            if handler.__class__.__name__ == 'FileHandler':
                need_file_handler = False
            if handler.__class__.__name__ == 'StreamHandler':
                need_console_handler = False
        self._log_level = log_level
        self._log_file = log_file
        self._formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        self._logger.setLevel(log_level)
        if need_file_handler:
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(self._formatter)
            self._logger.addHandler(file_handler)
        if self._verbose and need_console_handler:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(self._formatter)
            self._logger.addHandler(console_handler)
        self._logger.setLevel(log_level)
        self._logger.propagate = False

        # Generate workload and schedule dictionaries.
        if isinstance(graph, relay.expr.Function):
            node_dict = {}
            graph = bind_inputs(graph, input_shapes, dtype)
            expr2graph(graph, self._target_ops, node_dict, self._node_list)
        else:
            raise RuntimeError("Unsupported graph type: %s" % str(type(graph)))

        self._graph = graph
        self._in_nodes_dict = get_in_nodes(self._node_list, self._target_ops, input_shapes.keys())
        self._out_nodes_dict = get_out_nodes(self._in_nodes_dict)
        self._fetch_cfg()

        # Setup infer_layout for elemwise-like nodes
        # Note: graph tuner currently only supports tuning of single input and single output
        # op as target op, such as conv2d, dense and conv2d_transpose. In this case, we can
        # reuse infer_layout function from target ops for elemwise-like nodes. The behavior
        # is to modify the first tensor shape of input workload to the output shape of
        # elemwise-like node, and use infer_layout function from input op to generate layouts.
        input_names = self._input_shapes.keys()
        for idx in sorted(self._in_nodes_dict.keys()):
            if has_multiple_inputs(self._node_list, idx, input_names):
                node_entry = self._node_list[idx]
                node_entry["topi_op"] = []
                node_entry["workloads"] = []
                for input_idx in self._in_nodes_dict[idx]:
                    input_node = self._node_list[input_idx]
                    if not is_input_node(input_node, input_names):
                        input_topi_op = input_node["topi_op"][0]
                        node_entry["topi_op"].append(input_topi_op)
                        # Only replace the first input tensor
                        input_workload = input_node["workloads"][0]
                        first_tensor = input_workload[1]
                        dtype = first_tensor[-1]
                        new_shape = tuple([val.value for val in node_entry["types"][0].shape])
                        actual_workload = (input_workload[0],) + \
                                          ((new_shape + (dtype,)),) + input_workload[2:]
                        node_entry["workloads"].append(actual_workload)
                        if "record_candidates" not in node_entry:
                            node_entry["record_candidates"] = input_node["record_candidates"]
                    else:
                        # Input/parameter nodes get placeholders so positional
                        # indexing into topi_op/workloads stays aligned.
                        node_entry["topi_op"].append(None)
                        node_entry["workloads"].append(None)


    def _fetch_cfg(self):
"""Read and pre-process input schedules.""" + if isinstance(self._records, str): + records = load_from_file(self._records) + else: + records = self._records + cfg_dict = {} + for record in records: + in_measure, _ = record + workload = in_measure.task.workload + if workload not in cfg_dict: + cfg_dict[workload] = [] + cfg_dict[workload].append(record) + + cache_dict = {} + for key in self._in_nodes_dict: + node_entry = self._node_list[key] + if node_entry["op"] not in self._target_ops: + continue + workload = node_entry["workloads"][0] + if workload in cache_dict: + node_entry["record_candidates"] = cache_dict[workload] + continue + record_candidates = [] + infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + layout_tracking_dict = {} + for record in cfg_dict[workload]: + in_measure, out_measure = record + workload = in_measure.task.workload + cfg = in_measure.config + # For multiple cfgs which produces the same in/out layouts, + # only the most efficient one is preserved. + with self._target: + layouts = infer_layout_func(workload, cfg) + if layouts in layout_tracking_dict: + cost = out_measure.costs[0] + current_best_cost = layout_tracking_dict[layouts][1].costs[0] + if cost < current_best_cost: + layout_tracking_dict[layouts] = record + else: + layout_tracking_dict[layouts] = record + sorted_records = sorted(layout_tracking_dict.values(), + key=lambda item: item[1].costs[0]) + for i in range(min(self._max_sch_num, len(sorted_records))): + record_candidates.append(sorted_records[i]) + node_entry["record_candidates"] = record_candidates + cache_dict[workload] = record_candidates + + def _iterate_layout_transform(self, callback): + """Iterate all possible layout transformations and execute callback for each + iteration. callback function accepts 6 arguments: from_node_idx, to_node_idx, + from_sch_idx, to_sch_idx, args which represent the argument list of layout + transformation and is_valid showing whether this is a valid layout transformation. 
+ """ + input_names = self._input_shapes.keys() + for key, val in self._in_nodes_dict.items(): + node_entry = self._node_list[key] + target_input_idx = -1 + target_input_pos = -1 + if has_multiple_inputs(self._node_list, key, input_names): + for i, item in enumerate(val): + if not is_input_node(self._node_list[item], input_names): + target_input_idx = item + target_input_pos = i + break + + for i, item in enumerate(val): + i_idx = item + in_node_entry = self._node_list[i_idx] + if is_input_node(in_node_entry, input_names): + continue + + if node_entry["op"] in self._target_ops: + o_idx = key + o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + o_wkl = node_entry["workloads"][0] + i_topi_op = in_node_entry["topi_op"][0] + i_wkl = in_node_entry["workloads"][0] + pivot = 0 + while not i_wkl: + pivot += 1 + i_topi_op = in_node_entry["topi_op"][pivot] + i_wkl = in_node_entry["workloads"][pivot] + i_infer_layout_func = OP2LAYOUT[i_topi_op] + else: + o_idx = target_input_idx + if i <= target_input_pos: + continue + o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]] + o_wkl = node_entry["workloads"][target_input_pos] + i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]] + i_wkl = node_entry["workloads"][i] + + + for m, i_record in enumerate(in_node_entry["record_candidates"]): + for n, o_record in enumerate(node_entry["record_candidates"]): + i_cfg, o_cfg = i_record[0].config, o_record[0].config + with self._target: + i_input_info, i_output_info = i_infer_layout_func(i_wkl, i_cfg) + o_input_info, o_output_info = o_infer_layout_func(o_wkl, o_cfg) + if len(i_input_info) > 1 or len(i_output_info) > 1 or \ + len(o_input_info) > 1 or len(o_output_info) > 1: + raise RuntimeError("Graph tuner only supports target operator " + "with single input and single output. 
" + "Please check target_ops argument.") + + in_shape, in_layout = i_output_info[0] + if node_entry["op"] in self._target_ops: + _, out_layout = o_input_info[0] + else: + _, out_layout = o_output_info[0] + data_placeholder = tvm.placeholder(in_shape, name="data", + dtype=self._dtype) + args = [data_placeholder, in_layout, out_layout] + callback(i_idx, o_idx, m, n, args) + + + def _create_matrix_callback(self, from_node_idx, to_node_idx, from_sch_idx, + to_sch_idx, args): + """Create dictionary containing matrix format of layout transformation + between nodes.""" + sargs = serialize_args(args) + in_layout, out_layout = args[1], args[2] + ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(sargs) + idx_pair_key = (from_node_idx, to_node_idx) + + if in_layout == out_layout: + layout_transform_time = 0 + else: + layout_transform_time = \ + self._layout_transform_perf_records[ltf_workload][1].costs[0] + + if idx_pair_key not in self._layout_transform_interlayer_cost: + self._layout_transform_interlayer_cost[idx_pair_key] = [] + if len(self._layout_transform_interlayer_cost[idx_pair_key]) <= from_sch_idx: + self._layout_transform_interlayer_cost[idx_pair_key].append([]) + self._layout_transform_interlayer_cost[idx_pair_key][from_sch_idx]\ + .append(layout_transform_time) + + def benchmark_layout_transform(self, min_exec_num=100, timeout=10, + use_rpc=False, device_key=None, host="localhost", + port=9190, n_parallel=1, build_func='default', + layout_records=None, target_host=None, infer_layout=False): + """Benchmark all possible layout transformation in the graph, + given a set of schedule candidates for each workload of target operator. + + Parameters + ---------- + min_exec_num : int, optional + Minimum number of execution. Final execution time is the average of + all execution time. + + timeout : int, optional + Time out for each execution. + + use_rpc : boolean, optional + Whether to use rpc mode for benchmarking. 
+ + device_key : str, optional + Remote device key which can be queried by + python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190 + + host : str, optional + IP address used to create RPC tracker on host machine. + + port : int, optional + Port number used to create RPC tracker on host machine. + + n_parallel: int, optional + The number of measurement task that can run in parallel. + Set this according to the number of cpu cores (for compilation) and + the number of devices you have (for measuring generate code). + + build_func: str or callable, optional + 'default': call default builder. This works for normal target (llvm, cuda) + + 'ndk': use Android NDK to create shared library. Use this for android target. + + callable: customized build function for other backends (e.g. VTA). + See autotvm/measure/measure_methods.py::default_build_func for example. + + layout_records : str or iterator of (MeasureInput, MeasureResult). optional + Collection of layout_transform benchmarking records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + + If this argument is set, graph tuner will first check whether layout_transform + workload already exists in records and skip benchmarking if possible. + + target_host : str, optional + str or :any:`tvm.target.Target` optional + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. + By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + + infer_layout : bool, optional + Whether to infer layout transformation time if it doesn't exist in records, instead + of benchmarking on target device. 
+ + This might bring performance loss comparing to benchmarking layout transformation. + """ + self._logger.info("Start to benchmark layout transformation...") + if layout_records is None and infer_layout: + raise RuntimeError("Requires some records to infer layout transformation time.") + + if isinstance(layout_records, str): + layout_records = load_from_file(layout_records) + if not layout_records and infer_layout: + raise RuntimeError("Records must be non-empty to infer layout transformation time.") + + if isinstance(layout_records, str): + layout_records = load_from_file(layout_records) + num_flops, total_time = 0, 0 + if layout_records is not None: + for record in layout_records: + ltf_wkl = record[0].task.workload + self._layout_transform_perf_records[ltf_wkl] = record + input_shape = ltf_wkl[1][1] + flops = np.prod(input_shape) + num_flops += flops + total_time += record[1].costs[0] + avg_time = total_time / num_flops if num_flops > 0 else 0 + + args_list = [] + def _fetch_args_callback(from_node_idx, to_node_idx, from_sch_idx, + to_sch_idx, args): + """Callback function to fetch layout transform args""" + _, in_layout, out_layout = args + if in_layout != out_layout: + args_list.append(args) + + self._iterate_layout_transform(_fetch_args_callback) + + def _log_to_list(record_list): + """Callback to log result to a list.""" + def _callback(_, inputs, results): + """Callback implementation""" + record_list.append((inputs[0], results[0])) + return _callback + + builder = autotvm.LocalBuilder(n_parallel=n_parallel, build_func=build_func) + runner = autotvm.LocalRunner(number=min_exec_num, repeat=1, timeout=timeout) + if use_rpc: + if device_key is None: + raise RuntimeError("device_key need to be set to use rpc tracker mode.") + runner = autotvm.measure.RPCRunner(device_key, host, port, n_parallel=n_parallel, + number=min_exec_num, repeat=1, + timeout=timeout) + measure_option = autotvm.measure_option(builder=builder, runner=runner) + for args in args_list: + 
args = serialize_args(args) + ltf_workload = ('layout_transform',) + autotvm.task.args_to_workload(args) + if ltf_workload in self._layout_transform_perf_records: + continue + + if infer_layout: + input_shape = ltf_workload[1][1] + flops = 1 + for i in input_shape: + flops *= i + inferred_time = flops * avg_time + record_input = MeasureInput(target=self._target, task=None, config=None) + record_output = MeasureResult(costs=(inferred_time,), error_no=0, + all_cost=-1, timestamp=-1) + self._layout_transform_perf_records[ltf_workload] = (record_input, record_output) + continue + + records = [] + task = autotvm.task.create(layout_transform, args=args, target=self._target, + target_host=target_host) + task.workload = ltf_workload + tuner = autotvm.tuner.GridSearchTuner(task) + tuner.tune(n_trial=1, measure_option=measure_option, + callbacks=[_log_to_list(records)]) + if not isinstance(records[0][1].costs[0], float): + records[0] = (records[0][0], records[0][1]._replace(costs=(INVALID_LAYOUT_TIME,))) + self._layout_transform_perf_records[ltf_workload] = records[0] + + self._iterate_layout_transform(self._create_matrix_callback) + self._logger.info("Benchmarking layout transformation successful.") + + @property + def layout_transform_perf_records(self): + """Get layout transformation dictionary for input graph. + + Returns + ------- + layout_transform_perf_records : dict of tuple to (MeasureInput, MeasureResult) + Layout transformation dictionary for input graph. + """ + return self._layout_transform_perf_records + + + def get_optimal_records(self): + """Convert optimal record dictionary to a list of records + with ascending order of node index in graph. + + Returns + ------- + sch_list : list of tuple + List of records with ascending order of node index in graph. 
+ """ + ordered_index_list = sorted(self._optimal_record_dict.keys()) + ret = [] + for index in ordered_index_list: + node_entry = self._node_list[index] + if node_entry["op"] not in self._target_ops: + continue + ret.append(node_entry["record_candidates"][self._optimal_record_dict[index]]) + return ret + + def write_opt_sch2record_file(self, record_file="graph_opt_schedule.log"): + """Write graph level optimal schedules into file. + + Parameters + ---------- + record_file : str, optional + Output schedule file. + """ + with open(record_file, "a") as out_file: + records = self.get_optimal_records() + for record in records: + out_file.write(encode(record[0], record[1]) + "\n") + msg = "Writing optimal schedules to %s successfully." % record_file + self._logger.info(msg) + + @abstractmethod + def run(self, **kwargs): + """Run graph tuning.""" + pass diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py new file mode 100644 index 000000000000..4a512c224a1d --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/dynamic_programming_stage.py @@ -0,0 +1,358 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=too-many-instance-attributes,too-many-branches,too-many-statements,too-many-arguments,too-many-locals,invalid-name
"""Stage class for dynamic programming tuner"""
import numpy as np

from .utils import is_input_node


class DPStage(object):
    """Class to represent node in Markov decision process. A stage has states
    to represent different schedules of the current node. Since in this problem
    the action is the schedule selected for current node, action can be fully
    represented by states. No extra attribute needs for action.

    In most cases, instance of this class should be created through DPTuner.
    """
    def __init__(self, idx, input_shapes, node_list,
                 counted_nodes_set, layout_transform_interlayer_cost,
                 stage_dict, in_nodes_dict, out_nodes_dict,
                 dep_dict, target_ops, dtype="float32"):
        """Initialize a stage and create all states.

        Parameters
        ----------
        idx : int
            Index for current node.

        input_shapes : dict of string to tuple of int
            Input shapes for current graph.

        node_list : list of dict
            List of all nodes for current graph.

        counted_nodes_set : set of int
            Global set recording whether the execution time of a node has been counted.

        layout_transform_interlayer_cost : dict of tuple to list
            Dictionary maps node index pair to layout transformation time between them.

        stage_dict : dict of int to Stage
            Global dictionary for all stages mapping node index to stage.

        in_nodes_dict : dict of int to list of int
            Dictionary maps node index to corresponding input node index.

        out_nodes_dict : dict of int to list of int
            Dictionary maps node index to corresponding output node index.

        dep_dict : dict of int to set of int
            Dictionary maps node index to dependent node index.

        target_ops : list of str
            Target operators

        dtype : str, optional
            Data type.
        """
        # "_global_*" members are shared, mutable structures owned by the
        # tuner; this stage both reads and updates them.
        self._global_input_shapes = input_shapes
        self._global_input_names = input_shapes.keys()
        self._global_node_list = node_list
        self._global_counted_nodes_set = counted_nodes_set
        self._global_layout_transform_interlayer_cost = layout_transform_interlayer_cost
        self._global_stage_dict = stage_dict
        self._global_in_nodes_dict = in_nodes_dict
        self._global_out_nodes_dict = out_nodes_dict
        self._global_dep_dict = dep_dict

        self._idx = idx
        self._node_entry = self._global_node_list[idx]
        self._target_ops = target_ops
        self._wkl = self._node_entry["workloads"][0]
        self._record_list = self._node_entry["record_candidates"]
        self._dep = []
        self._dtype = dtype
        self._states = None
        self._full_states = None
        self._full_states_idx = None
        self._create_states()

    def _create_states(self):
        """Create states."""
        node = self._global_node_list[self._idx]
        if node["op"] in self._target_ops:
            self._create_op_states()
        else:
            self._create_multi_inputs_states()

    def _create_op_states(self):
        """State creation routine for nodes with target_op."""
        # Find the first non-input producer of this node, if any.
        input_idx = -1
        for index in self._global_in_nodes_dict[self._idx]:
            input_idx = index
            if not is_input_node(self._global_node_list[input_idx],
                                 self._global_input_names):
                break

        if is_input_node(self._global_node_list[input_idx],
                         self._global_input_names):
            # Fed directly by a graph input: states are just the candidate
            # kernel costs — no layout transform cost to add.
            self._full_states = np.array([record[1].costs[0]
                                          for record in self._record_list])
            self._states = self._full_states
        else:
            input_node_entry = self._global_node_list[input_idx]
            input_stage = self._global_stage_dict[input_idx]
            input_dep = input_stage.dep
            input_states = input_stage.states
            input_flatten_states = input_states.flatten()
            input_record_list = input_node_entry["record_candidates"]
            num_schedules = len(self._record_list)
            num_input_schedules = len(input_record_list)
            num_input_states = input_flatten_states.shape[0]

            # States tensor shape: (this node's schedules, input's schedules,
            # then one axis per dependency of the input stage).
            full_states_shape = tuple([num_schedules, num_input_schedules] +
                                      [len(self._global_node_list[dep_idx]["record_candidates"])
                                       for dep_idx in input_dep])
            self._full_states = np.zeros(full_states_shape).flatten().astype("float32")
            self._full_states_idx = [self._idx, input_idx] + input_dep
            dep_multiplier = 1
            for i in range(2, len(full_states_shape)):
                dep_multiplier *= full_states_shape[i]
            input_node_time_counted = input_idx in self._global_counted_nodes_set

            for i in range(num_schedules):
                current_sch_time = float(self._record_list[i][1].costs[0])
                for j in range(num_input_states):
                    input_sch_idx = j // dep_multiplier
                    layout_transform_time = \
                        self._global_layout_transform_interlayer_cost \
                            [(input_idx, self._idx)][input_sch_idx][i]

                    # Add the input's own time only once across all stages.
                    if input_node_time_counted:
                        total_time = current_sch_time + layout_transform_time
                    else:
                        total_time = \
                            current_sch_time + layout_transform_time + input_flatten_states[j]
                    current_state_idx = i * num_input_states + j
                    self._full_states[current_state_idx] = total_time

            if not input_node_time_counted:
                self._global_counted_nodes_set.add(input_idx)
            self._full_states = self._full_states.reshape(full_states_shape)

            # If out degree of input node is 1, we can remove the dimension of input node,
            # since the states of input node will not be needed any more. Otherwise, input
            # node should become a dependency.
            if len(self._global_out_nodes_dict[input_idx]) == 1:
                self._states = np.amin(self._full_states, axis=1)
                self._dep = list(input_dep)
            else:
                self._states = self._full_states
                self._dep = [input_idx,] + input_dep

        # Update global dependency dictionary.
        # This is to monitor the dependency states to decide
        # when a dependency can be eliminated, so that total
        # number of states can be largely reduced.
        for dep_idx in self._dep:
            self._global_dep_dict[dep_idx].remove(self._idx)
            for child in self._global_out_nodes_dict[self._idx]:
                self._global_dep_dict[dep_idx].add(child)
        if len(self._global_out_nodes_dict[self._idx]) > 1:
            self._global_dep_dict[self._idx] = set()
            for child in self._global_out_nodes_dict[self._idx]:
                self._global_dep_dict[self._idx].add(child)

    def _create_multi_inputs_states(self):
        """State creation routine for multi_input operator

        In tvm, layout transformation for an elemwise-like follow the rule which
        all input operators transform their layouts to the leftmost input operator
        layout. For example:
                            elemwise-sum
                            |    |    |
                            |    |    |
                           op0  op1  op2
        In this block, the possible layout transformations are: op1 -> op0 and op2 -> op0.
        In graph tuning, a 3-D array with shape (k0, k1, k2) can represent the layout
        transformations between these three nodes. It is also possible some earlier states
        belong to other nodes(We name them as dependency) are required for dynamic programming.
        The final states array for this elemwise-sum can be with shape (e0, k0, k1, e1, k2).
        To iterate through all states, we first align the shape of op0, op1 and op2 to be
        (e0, k0, k1, e1, k2) by broadcasting the original states. We also record the axis of
        each input node in the states array, together with the multiplier. For example,
        the axis index for op0 is 1, and multiplier is k1 * e1 * k2. If current iterating index
        in the flatten array is i, the index of op0 can be computed as:
        i % (k0 * k1 * e1 * k2) // (k1 * e1 * k2).
+ """ + full_input_node_list = list(self._global_in_nodes_dict[self._idx]) + input_index_list = [] + # Remove input and parameter nodes + for input_idx in full_input_node_list: + if not is_input_node(self._global_node_list[input_idx], + self._global_input_names): + input_index_list.append(input_idx) + + # Generate new states + states_list, aligned_node_list = DPStage.align_states(input_index_list, + self._global_stage_dict, + self._global_node_list) + target_node_idx, target_major_axis, target_multiplier, target_states = states_list[0] + aligned_shape = target_states.shape + self._full_states = np.zeros(aligned_shape).astype("float32").flatten() + self._full_states_idx = list(aligned_node_list) + num_states = self._full_states.shape[0] + node_time_counted = [item[0] in self._global_counted_nodes_set for item in states_list] + target_states = target_states.flatten() + src_states_list = [states_list[i][3].flatten() for i in range(1, len(states_list))] + + for i in range(num_states): + target_sch_idx = (i % (target_multiplier * + aligned_shape[target_major_axis])) // target_multiplier + if node_time_counted[0]: + new_state = 0 + else: + new_state = target_states[i] + + for j in range(1, len(states_list)): + src_states = src_states_list[j - 1] + src_node_idx, src_major_axis, src_multiplier, _ = states_list[j] + src_sch_idx = (i % (src_multiplier * + aligned_shape[src_major_axis])) // src_multiplier + layout_transform_time = \ + self._global_layout_transform_interlayer_cost\ + [(src_node_idx, target_node_idx)][src_sch_idx][target_sch_idx] + + if node_time_counted[j]: + new_state += layout_transform_time + else: + new_state += layout_transform_time + src_states[i] + self._full_states[i] = new_state + + for i, node_counted in enumerate(node_time_counted): + if not node_counted: + self._global_counted_nodes_set.add(states_list[i][0]) + self._full_states = self._full_states.reshape(aligned_shape) + + # Remove dependency to reduce states + reduced_states = 
np.array(self._full_states) + reduced_states_transpose = [states_list[0][1]] + reduced_states_dep_list = [] + self._dep = [] + for i in range(len(reduced_states.shape)): + if i != states_list[0][1]: + reduced_states_transpose.append(i) + reduced_states_dep_list.append(aligned_node_list[i]) + reduced_states = np.transpose(reduced_states, reduced_states_transpose) + shift = 0 + for i, dep in enumerate(reduced_states_dep_list): + if dep not in self._global_dep_dict or len(self._global_dep_dict[dep]) == 1: + self._global_dep_dict.pop(dep, None) + reduced_states = np.amin(reduced_states, axis=i+1-shift) + shift += 1 + else: + self._dep.append(dep) + self._states = reduced_states + + # Update dependency + for dep in self._dep: + self._global_dep_dict[dep].remove(self._idx) + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[dep].add(child) + if len(self._global_out_nodes_dict[self._idx]) > 1: + self._global_dep_dict[self._idx] = set() + for child in self._global_out_nodes_dict[self._idx]: + self._global_dep_dict[self._idx].add(child) + + @property + def dep(self): + """Get dependency list.""" + return self._dep + + @property + def states(self): + """Get states.""" + return self._states + + @property + def full_states(self): + """Get complete states.""" + return self._full_states + + @property + def full_states_idx(self): + """Get node index of complete states.""" + return self._full_states_idx + + @staticmethod + def align_states(input_index_list, stage_dict, node_list): + """Align all input node states shapes to be the same and transpose/reshape properly. + + This is used in creating multi_input operator states. + + Parameters + ---------- + input_index_list : list of int + List of input node index. + + stage_dict : dict of int to Stage + Global dictionary of node index to stage. + + node_list : list of dict + List of all nodes for current graph. + + Returns + ------- + states_list : list of tuple + List of aligned states. 

        aligned_node_list : list of int
            List of node index for aligned states.
        """
        aligned_node_list = list(input_index_list)
        states_list = []
        # Collect every dependency of every input stage so all states can be
        # broadcast into one common shape.
        for input_idx in input_index_list:
            input_node_stage = stage_dict[input_idx]
            for dep_idx in input_node_stage.dep:
                if dep_idx not in aligned_node_list:
                    aligned_node_list.append(dep_idx)
        aligned_shape = tuple([len(node_list[idx]["record_candidates"])
                               for idx in aligned_node_list])
        for input_idx in input_index_list:
            input_node_stage = stage_dict[input_idx]
            input_node_shape_idx_list = [input_idx] + input_node_stage.dep
            transpose_idx_list = []
            reshape_list = []
            major_axis = -1
            for i, idx in enumerate(aligned_node_list):
                if input_idx == idx:
                    major_axis = i
                if idx in input_node_shape_idx_list:
                    transpose_idx_list.append(idx)
                    reshape_list.append(aligned_shape[i])
                else:
                    # Axis not owned by this input: size-1 placeholder so
                    # broadcast_to can expand it below.
                    reshape_list.append(1)
            transpose_list = [input_node_shape_idx_list.index(idx) for idx in transpose_idx_list]
            input_node_states = np.transpose(input_node_stage.states, tuple(transpose_list))
            input_node_states = np.reshape(input_node_states, tuple(reshape_list))
            input_node_states = np.broadcast_to(input_node_states, aligned_shape)
            # Multiplier converts a flat index into this input's schedule index.
            multiplier = 1
            for i in range(major_axis + 1, len(aligned_shape)):
                multiplier *= aligned_shape[i]
            states_list.append((input_idx, major_axis, multiplier, input_node_states))
        return states_list, aligned_node_list
diff --git a/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py
new file mode 100644
index 000000000000..11571f2bdef9
--- /dev/null
+++ b/python/tvm/autotvm/graph_tuner/dynamic_programming_tuner.py
@@ -0,0 +1,189 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-error,too-many-locals,too-many-statements,too-many-branches,unused-variable
"""Dynamic programming tuner."""
import sys
import numpy as np

from .base_graph_tuner import BaseGraphTuner
from .dynamic_programming_stage import DPStage
from .utils import has_multiple_inputs, is_input_node

# Python 2/3 compatible queue import.
if sys.version_info[0] == 3:
    import queue
else:
    import Queue as queue

class DPTuner(BaseGraphTuner):
    """Tuner which uses dynamic programming to solve MDP problem.

    Note: currently dynamic programming is used to solve this MDP problem. However,
    this problem is intrinsically non-polynomial. DP can't apply for more complicated
    models, such as networks with many element-wise sum operators. In this case, switch
    to heuristic algorithm such as PBQP tuner.
    """
    def __init__(self, *args, **kwargs):
        """Create a dynamic programming tuner.
+ """ + super(DPTuner, self).__init__(*args, **kwargs) + self._num_states = self._max_num_states = None + self._stage_dict = {} + self._dep_dict = {} + self._counted_nodes_set = set() + + self._global_data_dict = { + "dtype": self._dtype, + "counted_nodes_set": self._counted_nodes_set, + "stage_dict": self._stage_dict, + "in_nodes_dict": self._in_nodes_dict, + "out_nodes_dict": self._out_nodes_dict, + "dep_dict": self._dep_dict, + "node_list": self._node_list, + "input_shapes": self._input_shapes, + "layout_transform_interlayer_cost": self._layout_transform_interlayer_cost + } + + def _check_num_states(self, num_states): + """Track the number of states.""" + self._num_states += num_states + if self._max_num_states is not None: + if self._num_states > self._max_num_states: + raise RuntimeError("Too many states detected while running dynamic " + "programming: got %d states but upper limit is %d." % + (self._num_states, self._max_num_states)) + + def _forward(self): + """Forward pass in DP to generate states for all stages. + """ + self._logger.info("Start forward pass...") + for node_idx in sorted(self._in_nodes_dict.keys()): + stage = DPStage(idx=node_idx, target_ops=self._target_ops, + **self._global_data_dict) + self._check_num_states(stage.full_states.size) + self._stage_dict[node_idx] = stage + self._logger.info("Finished forward pass.") + + def _backward(self): + """Backward pass in DP to generate optimal solution. 
+ """ + self._logger.info("Start backward pass...") + input_names = self._input_shapes.keys() + optimal_record_dict = {} + # Pick optimal schedule for output nodes + output_idx_list = [] + for key, val in self._out_nodes_dict.items(): + if not val: + output_idx_list.append(key) + states_list, aligned_node_list = DPStage.align_states(output_idx_list, self._stage_dict, + self._node_list) + num_states = states_list[0][3].size + self._check_num_states(num_states * len(output_idx_list)) + aligned_node_shape = states_list[0][3].shape + min_time = 0 + min_pos = -1 + for states in states_list: + min_time += np.amax(states[3]) + flatten_states_list = [current_states[3].flatten() for current_states in states_list] + for i in range(num_states): + current_time = 0 + for j, current_states in enumerate(states_list): + current_time += flatten_states_list[j][i] + if min_time > current_time: + min_time = current_time + min_pos = i + for i, states in enumerate(states_list): + current_major_axis = states[1] + current_sch_idx = (min_pos % (states[2] * + aligned_node_shape[current_major_axis])) // states[2] + optimal_record_dict[aligned_node_list[i]] = current_sch_idx + # Pick optimal schedule for dependencies of output nodes + for i in range(len(states_list), len(aligned_node_list)): + multiplier = 1 + for j in range(i + 1, len(aligned_node_list)): + multiplier *= aligned_node_shape[j] + optimal_record_dict[aligned_node_list[i]] = \ + min_pos // multiplier % aligned_node_shape[i] + + # Backward pass to get optimal schedules for other nodes + bfs_q = queue.Queue() + visited = set() + for out_idx in output_idx_list: + bfs_q.put(out_idx) + while not bfs_q.empty(): + node_idx = bfs_q.get() + visited.add(node_idx) + if is_input_node(self._node_list[node_idx], input_names): + continue + optimal_sch_idx = optimal_record_dict[node_idx] + full_states = self._stage_dict[node_idx].full_states + if not has_multiple_inputs(self._node_list, node_idx, input_names): + input_idx = 
self._in_nodes_dict[node_idx][0] + if is_input_node(self._node_list[input_idx], input_names): + continue + if input_idx not in visited: + bfs_q.put(input_idx) + if input_idx not in optimal_record_dict: + dep_list = self._stage_dict[node_idx].dep + dep_idx = tuple([optimal_record_dict[item] for item in dep_list]) + tmp = np.argmin(full_states, axis=1) + optimal_input_sch_idx = tmp[(optimal_sch_idx,) + dep_idx] + optimal_record_dict[input_idx] = optimal_input_sch_idx + else: + input_idx_list = self._in_nodes_dict[node_idx] + optimal_record_dict[input_idx_list[0]] = optimal_sch_idx + full_states_idx = self._stage_dict[node_idx].full_states_idx + tmp = full_states[optimal_sch_idx] + new_states_idx, new_states_pos = [], [] + visited_states_idx, visited_states_pos = [], [] + for i in range(1, len(full_states_idx)): + if full_states_idx[i] in optimal_record_dict: + visited_states_idx.append(full_states_idx[i]) + visited_states_pos.append(i - 1) + else: + new_states_idx.append(full_states_idx[i]) + new_states_pos.append(i - 1) + if visited_states_idx: + tmp = np.transpose(tmp, tuple(visited_states_pos + new_states_pos)) + tmp = tmp[tuple([optimal_record_dict[idx] for idx in visited_states_idx])] + min_pos = np.argmin(tmp) + multiplier = 1 + for i in range(len(new_states_idx)): + multiplier *= full_states.shape[new_states_pos[i] + 1] + for pos, idx in zip(new_states_pos, new_states_idx): + multiplier //= full_states.shape[pos + 1] + optimal_record_dict[idx] = min_pos // multiplier + min_pos %= multiplier + for input_idx in input_idx_list: + if input_idx not in visited: + bfs_q.put(input_idx) + + self._optimal_record_dict = optimal_record_dict + for node_idx, _ in self._in_nodes_dict.items(): + if self._node_list[node_idx]["op"] not in self._target_ops: + continue + self._logger.info("Finished backward pass...") + + def run(self, **kwargs): + """Run dynamic programming solver. 
+ """ + max_num_states = None if "max_num_states" not in kwargs else kwargs["max_num_states"] + self._num_states = 0 + self._max_num_states = max_num_states + self._logger.info("Start to run dynamic programming algorithm...") + self._forward() + self._backward() + self._logger.info("Finished DPExecutor run.") diff --git a/python/tvm/autotvm/graph_tuner/pbqp_tuner.py b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py new file mode 100644 index 000000000000..1d7089ef248b --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/pbqp_tuner.py @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,too-many-locals +"""Partitioned Boolean Quadratic Programming Tuner""" +from ._base import INVALID_LAYOUT_TIME +from .base_graph_tuner import BaseGraphTuner +from .utils import is_input_node, has_multiple_inputs + + +class PBQPTuner(BaseGraphTuner): + """An approximation method to deal with intractably + large size of graph tuning problem. + + This graph coloring algorithm mainly comes from: + + Lang Hames and Bernhard Scholz. + Nearly optimal register allocation with pbqp.JMLC 2006. + LNCS, vol.4228,pp. 
class PBQPTuner(BaseGraphTuner):
    """An approximation method to deal with intractably
    large size of graph tuning problem.

    This graph coloring algorithm mainly comes from:

    Lang Hames and Bernhard Scholz.
    Nearly optimal register allocation with pbqp. JMLC 2006.
    LNCS, vol.4228, pp. 346-361, 2006
    """
    def __init__(self, *args, **kwargs):
        """Create a partitioned boolean quadratic programming tuner.
        """
        super(PBQPTuner, self).__init__(*args, **kwargs)

        # Remove input nodes from the in-edge lists: they carry no
        # schedule choice of their own.
        input_names = self._input_shapes.keys()
        for node_idx in self._out_nodes_dict:
            if is_input_node(self._node_list[node_idx], input_names):
                for out_node_idx in self._out_nodes_dict[node_idx]:
                    self._in_nodes_dict[out_node_idx].remove(node_idx)

        # Undirected adjacency: in-neighbors followed by out-neighbors.
        self._adj_dict = {}
        for node_idx in self._in_nodes_dict:
            self._adj_dict[node_idx] = list(self._in_nodes_dict[node_idx]) + \
                                       list(self._out_nodes_dict[node_idx])

        # Per-node cost vector: first measured cost of each record candidate.
        self._record_cost_dict = {}
        for key in self._in_nodes_dict:
            self._record_cost_dict[key] = []
            for record in self._node_list[key]["record_candidates"]:
                self._record_cost_dict[key].append(record[1].costs[0])

        self._max_degree = -1
        self._node_degree_dict = {}
        for node_idx in self._in_nodes_dict:
            node_degree = self._get_degree(node_idx)
            self._node_degree_dict[node_idx] = node_degree
            self._max_degree = max(self._max_degree, node_degree)

        # Bucket i holds nodes of current degree i; sized +2 so degree can
        # grow by one during RN reductions without reallocation.
        self._stack = []
        self._buckets = [[] for _ in range(self._max_degree + 2)]
        for node_idx in sorted(self._in_nodes_dict):
            node_degree = self._get_degree(node_idx)
            self._buckets[node_degree].append(node_idx)

        # Cleared to False the first time an RN (heuristic) reduction runs.
        self._is_optimal = True

    def _get_degree(self, node_idx):
        """Get node degree.
        """
        return len(self._adj_dict[node_idx])

    def _reorder_adj_nodes(self, node_idx):
        """Update buckets list with current adjacency list.
        """
        for adj_node in self._adj_dict[node_idx]:
            current_degree = self._get_degree(adj_node)
            prev_degree = self._node_degree_dict[adj_node]
            if prev_degree != current_degree:
                self._buckets[prev_degree].remove(adj_node)
                self._buckets[current_degree].insert(0, adj_node)
                self._node_degree_dict[adj_node] = current_degree

    def _remove_node(self, node_idx):
        """Remove node from graph. Update adjacency list accordingly.
        """
        node_degree = self._get_degree(node_idx)
        self._buckets[node_degree].remove(node_idx)
        for adj_node in self._adj_dict[node_idx]:
            self._adj_dict[adj_node].remove(node_idx)

    def _insert_edge(self, node_x, node_y, adj_cost_matrix):
        """Insert an edge between two nodes.
        """
        self._layout_transform_interlayer_cost[(node_x, node_y)] = adj_cost_matrix
        # Store the transpose for the reverse direction.
        self._layout_transform_interlayer_cost[(node_y, node_x)] = []
        for i in range(len(adj_cost_matrix[0])):
            self._layout_transform_interlayer_cost[(node_y, node_x)].append([])
            for cost_vec in adj_cost_matrix:
                self._layout_transform_interlayer_cost[(node_y, node_x)][i] \
                    .append(cost_vec[i])

        self._adj_dict[node_x].append(node_y)
        self._adj_dict[node_y].append(node_x)

    def _backward_insert_node(self, node_idx):
        """Reinsert node in backward pass.
        """
        for adj_node in self._adj_dict[node_idx]:
            self._adj_dict[adj_node].append(node_idx)

    def _RI_reduction(self, node_idx):
        """Reduce nodes with degree 1.

        The node's cost vector is folded into its single neighbor, then the
        node is removed and pushed on the stack for the backward pass.
        """
        adj_node = self._adj_dict[node_idx][0]
        ltf_matrix = self._layout_transform_interlayer_cost[(adj_node, node_idx)]
        for i, cost_vec in enumerate(ltf_matrix):
            min_cost = INVALID_LAYOUT_TIME
            for j, cost in enumerate(cost_vec):
                min_cost = min(min_cost, cost + self._record_cost_dict[node_idx][j])
            self._record_cost_dict[adj_node][i] += min_cost
        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _RII_reduction(self, node_idx):
        """Reduce nodes with degree 2.

        Folds the node into a delta matrix on the edge between its two
        neighbors, creating that edge if it does not exist yet.
        """
        adj_node_x, adj_node_y = self._adj_dict[node_idx]
        ltf_matrix_x = self._layout_transform_interlayer_cost[(adj_node_x, node_idx)]
        ltf_matrix_y = self._layout_transform_interlayer_cost[(adj_node_y, node_idx)]
        delta_matrix = [[] for _ in range(len(ltf_matrix_x))]
        for i, cost_vec_x in enumerate(ltf_matrix_x):
            for j, cost_vec_y in enumerate(ltf_matrix_y):
                min_cost = INVALID_LAYOUT_TIME
                for k in range(len(self._record_cost_dict[node_idx])):
                    min_cost = min(min_cost, cost_vec_x[k] + cost_vec_y[k]
                                   + self._record_cost_dict[node_idx][k])
                delta_matrix[i].append(min_cost)

        if adj_node_x == adj_node_y:
            # Both edges go to the same neighbor: only the diagonal matters.
            for i, delta_row in enumerate(delta_matrix):
                self._record_cost_dict[adj_node_x][i] += delta_row[i]
        elif adj_node_x in self._adj_dict[adj_node_y]:
            # Edge already exists: accumulate delta in both directions.
            for i, _ in enumerate(delta_matrix):
                for j, delta in enumerate(delta_matrix[i]):
                    self._layout_transform_interlayer_cost[(adj_node_x, adj_node_y)][i][j] \
                        += delta
                    self._layout_transform_interlayer_cost[(adj_node_y, adj_node_x)][j][i] \
                        += delta
        else:
            self._insert_edge(adj_node_x, adj_node_y, delta_matrix)

        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _RN_reduction(self, node_idx):
        """Reduce nodes with degree greater than 2.

        Heuristic step: greedily commits this node to its locally-best
        record, so the overall solution is no longer guaranteed optimal.
        """
        min_cost = INVALID_LAYOUT_TIME
        record_idx = -1

        for i, record_cost in enumerate(self._record_cost_dict[node_idx]):
            current_cost = record_cost
            for adj_node in self._adj_dict[node_idx]:
                ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
                adj_record_cost = list(self._record_cost_dict[adj_node])
                for j, ltf_cost in enumerate(ltf_matrix[i]):
                    adj_record_cost[j] += ltf_cost
                current_cost += min(adj_record_cost)
            if current_cost < min_cost:
                min_cost = current_cost
                record_idx = i

        if record_idx < 0:
            # FIX: original message misspelled "solution" as "soltuion".
            raise RuntimeError("Can't find a solution for node %d when "
                               "applying RN reduction" % node_idx)
        self._optimal_record_dict[node_idx] = record_idx
        self._is_optimal = False

        for adj_node in self._adj_dict[node_idx]:
            ltf_matrix = self._layout_transform_interlayer_cost[(node_idx, adj_node)]
            for i, ltf_cost in enumerate(ltf_matrix[record_idx]):
                self._record_cost_dict[adj_node][i] += ltf_cost

        self._remove_node(node_idx)
        self._reorder_adj_nodes(node_idx)
        self._stack.append(node_idx)

    def _forward(self):
        """Forward pass in PBQP to reduce nodes.
        """
        while True:
            if self._buckets[1]:
                node_idx = self._buckets[1][0]
                self._RI_reduction(node_idx)
            elif self._max_degree >= 2 and self._buckets[2]:
                node_idx = self._buckets[2][0]
                self._RII_reduction(node_idx)
            elif self._max_degree >= 3:
                max_degree_node = -1
                for i in range(self._max_degree, 2, -1):
                    if self._buckets[i]:
                        max_degree_node = self._buckets[i][0]
                        self._RN_reduction(max_degree_node)
                        break
                if max_degree_node < 0:
                    break
            else:
                break

    def _backward(self):
        """Backward pass in PBQP to generate optimal solution.
        """
        # Solve nodes left in the forward graph (degree 0).
        for node_idx in self._buckets[0]:
            record_costs = self._record_cost_dict[node_idx]
            min_cost = min(record_costs)
            self._optimal_record_dict[node_idx] = record_costs.index(min_cost)

        # Solve reduced nodes in reverse reduction order.
        for node_idx in reversed(self._stack):
            self._backward_insert_node(node_idx)
            if node_idx not in self._optimal_record_dict:
                record_costs = list(self._record_cost_dict[node_idx])
                for adj_node in self._adj_dict[node_idx]:
                    adj_optimal_idx = self._optimal_record_dict[adj_node]
                    for i, _ in enumerate(record_costs):
                        record_costs[i] += \
                            self._layout_transform_interlayer_cost \
                                [(node_idx, adj_node)][i][adj_optimal_idx]
                min_cost = min(record_costs)
                self._optimal_record_dict[node_idx] = record_costs.index(min_cost)

    def run(self, **kwargs):
        """Run partitioned boolean quadratic programming tuner.
        """
        self._logger.info("Start to run PBQP algorithm...")
        # Define virtual record lists and layout transformation matrices
        # for multi-input nodes.
        input_names = self._input_shapes.keys()
        temp = {}
        for key, val in self._in_nodes_dict.items():
            target_input_idx = -1
            target_input_pos = -1
            if has_multiple_inputs(self._node_list, key, input_names):
                # First non-input ancestor acts as the "target" whose layout
                # the remaining inputs must agree with.
                for i, item in enumerate(val):
                    if not is_input_node(self._node_list[item], input_names):
                        target_input_idx = item
                        target_input_pos = i
                        break
                # Identity matrix: same record index costs 0, others invalid.
                temp[(target_input_idx, key)] = []
                record_candidates = self._node_list[target_input_idx]["record_candidates"]
                for j in range(len(record_candidates)):
                    temp[(target_input_idx, key)].append([])
                    for k in range(len(record_candidates)):
                        temp[(target_input_idx, key)][j].append(0 if j == k
                                                                else INVALID_LAYOUT_TIME)

                for j in range(target_input_pos + 1, len(val)):
                    input_idx = val[j]
                    if is_input_node(self._node_list[input_idx], input_names):
                        continue
                    temp[(input_idx, key)] = \
                        self._layout_transform_interlayer_cost[(input_idx, target_input_idx)]
        self._layout_transform_interlayer_cost.update(temp)

        # Create reverse layout transformation matrices (transposes).
        temp = {}
        for idx_pair, ltf_matrix in self._layout_transform_interlayer_cost.items():
            reverse_key = (idx_pair[1], idx_pair[0])
            reverse_matrix = [[] for _ in range(len(ltf_matrix[0]))]
            for i, _ in enumerate(ltf_matrix):
                for j, ltf in enumerate(ltf_matrix[i]):
                    reverse_matrix[j].append(ltf)
            temp[reverse_key] = reverse_matrix
        self._layout_transform_interlayer_cost.update(temp)

        self._forward()
        self._backward()
        is_optimal = "optimal" if self._is_optimal else "sub-optimal"
        msg = "Finished PBQPExecutor run. Got %s solution." % is_optimal
        self._logger.info(msg)
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Graph tuner utility functions""" +from __future__ import absolute_import + +from . import traverse_graph +from . import utils + +from .traverse_graph import expr2graph, get_direct_ancestor, get_in_nodes, \ + get_out_nodes +from .utils import has_multiple_inputs, is_input_node, bind_inputs diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py new file mode 100644 index 000000000000..08f1017e7fb8 --- /dev/null +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
# pylint: disable=too-many-locals,too-many-statements,too-many-branches,protected-access
"""API for graph traversing."""
import threading

import topi

from tvm import relay, autotvm
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv

from .._base import RULE_OUT_NODE_NAMES
from .utils import has_multiple_inputs, is_input_node


# Setup relay op base name -> topi compute functions
# NOTE: To add more ops, change the following dictionary.
OP2COMPUTE = {
    "conv2d": [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw],
}


def expr2graph(expr, target_ops, node_dict, node_list):
    """Convert relay expr to graph data structure
    and fetch workloads of target operators.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    target_ops: List of str
        List of target relay base op name

    node_dict : dictionary from tvm.relay.Expr to int
        Dictionary to record node index

    node_list : list of dictionary
        List of nodes which contains all expr in the input relay function.
        Each node will be stored as a dictionary in the format of
        {"op": str, "node": tvm.relay.expr, "inputs": [int], "types": [tvm.relay.Type],
         "name": str, "workloads": [tuple], "topi_op": [function]}
    """
    env = TaskExtractEnv.get(allow_duplicate=True)
    topi_funcs = []
    for op_name in target_ops:
        if op_name not in OP2COMPUTE:
            raise RuntimeError("Not supported relay op in graph tuner: %s"
                               % op_name)
        topi_funcs += OP2COMPUTE[op_name]
    env.reset(topi_funcs)
    _expr2graph_impl(expr, target_ops, node_dict, node_list)
    # Tasks were collected by the tracing build in topo-order, so walk the
    # node list once and attach workload/topi_op info positionally.
    task_pos = 0
    for node_entry in node_list:
        if node_entry["op"] in target_ops:
            task_name, args = env.task_collection[task_pos]
            task = autotvm.task.create(task_name, args,
                                       target="llvm",
                                       target_host=None,
                                       template_key='direct')
            node_entry["workloads"] = [task.workload]
            node_entry["topi_op"] = [task_name]
            task_pos += 1


def _expr2graph_impl(expr, target_ops, node_dict, node_list):
    """Implementation to convert relay expr to graph data structure
    """
    def _traverse_expr(node):
        if node in node_dict:
            return
        node_index = len(node_list)
        node_entry = {"node": node, "inputs": [], "types": [],
                      "op": "null", "name": None}

        if isinstance(node, Call):
            op_name = node.op.name.split(".")[-1]
            node_entry["op"] = op_name
            for arg in node.args:
                in_node_idx = node_dict[arg]
                # Tuple-ish arguments are flattened into their members.
                if isinstance(arg, (Tuple, TupleGetItem)):
                    node_entry["inputs"] += node_list[in_node_idx]["inputs"]
                else:
                    node_entry["inputs"].append([in_node_idx, 0, 0])
            infer_out = relay.ir_pass.infer_type(node)
            out_type = infer_out._checked_type_
            if isinstance(out_type, TensorType):
                node_entry["types"].append(out_type)
            elif isinstance(out_type, TupleType):
                for tuple_type in out_type.fields:
                    node_entry["types"].append(tuple_type)
            else:
                raise RuntimeError("Unsupported output type %s in operator %s"
                                   % (type(out_type), op_name))

            # Utilize tracing target to fetch workload with topo-order.
            # Since we only need workload, dummy target can be used to
            # create task.
            if op_name in target_ops:
                params = []
                for i, input_idx in enumerate(node_entry["inputs"]):
                    input_node_entry = node_list[input_idx[0]]
                    input_type = input_node_entry["types"][input_idx[1]]
                    if not isinstance(input_node_entry["node"], (Var, Call)):
                        raise RuntimeError("Graph tuner can only tune target "
                                           "operators with input node of type "
                                           "relay.expr.Var or relay.expr.Call. Now "
                                           "find a target op %s with input type %s"
                                           % (op_name, str(type(input_node_entry["node"]))))
                    free_var = relay.Var("var_%d" % i, input_type)
                    params.append(free_var)
                call = relay.Call(node.op, params, node.attrs)
                func = relay.Function(params, call)
                relay.backend.compile_engine.get().clear()
                # Build in a separate thread so per-thread tracing state does
                # not leak into the caller.
                build_thread = threading.Thread(target=relay.build,
                                                args=(func,
                                                      "llvm -device=tracing",
                                                      None,
                                                      None))
                build_thread.start()
                build_thread.join()
        elif isinstance(node, Var):
            node_entry["name"] = node.name_hint
            node_entry["types"] = [node.type_annotation]
        elif isinstance(node, Function):
            # Ignore root node since it equals to input function expression
            if node != expr:
                _expr2graph_impl(node, target_ops, node_dict, node_list)
            return
        elif isinstance(node, TupleGetItem):
            node_entry["op"] = "TupleGetItem"
            in_node_idx = node_dict[node.tuple_value]
            node_entry["inputs"].append([in_node_idx, node.index, 0])
        elif isinstance(node, Tuple):
            node_entry["op"] = "Tuple"
            for tuple_item in node:
                in_node_idx = node_dict[tuple_item]
                if isinstance(tuple_item, TupleGetItem):
                    node_entry["inputs"] += node_list[in_node_idx]["inputs"]
                elif isinstance(tuple_item, Tuple):
                    raise RuntimeError("Graph tuner doesn't support nested tuple.")
                else:
                    node_entry["inputs"].append([in_node_idx, 0, 0])
        elif isinstance(node, Constant):
            pass
        elif isinstance(node, relay.op.op.Op):
            return
        else:
            raise RuntimeError("Not supported relay node type in graph tuning: %s"
                               % str(type(node)))
        node_dict[node] = node_index
        node_list.append(node_entry)

    relay.ir_pass.post_order_visit(expr, _traverse_expr)


def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names):
    """Given a node_list in relay function and a node index, return the
    closest ancestor which has op_name as operator name or is multi_input operator.

    If node has multiple inputs, multiple ancestor nodes will be returned.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in a graph.

    visited_dict : dict of int to int
        Nodes and corresponding ancestors which have been visited.

    target_ops: List of str
        List of target relay base op name

    node_idx : int
        Input node index.

    input_names : list of str
        Names of graph input nodes.

    Returns
    -------
    out : list of int
        List of ancestor node index.
    """
    if node_idx in visited_dict:
        return visited_dict[node_idx]
    if is_input_node(node_list[node_idx], input_names):
        return [node_idx]
    node = node_list[node_idx]
    # Rule out injective operators: a node fed by any ruled-out op has
    # no usable ancestors.
    is_rule_out = False
    for item_idx in node["inputs"]:
        item = node_list[item_idx[0]]
        if item["op"] in RULE_OUT_NODE_NAMES:
            is_rule_out = True
            break
    if is_rule_out:
        visited_dict[node_idx] = []
        return []

    node_direct_ancestor = []
    for item_idx in node["inputs"]:
        item = node_list[item_idx[0]]
        is_multiple_inputs = has_multiple_inputs(node_list, item_idx[0], input_names)
        if item["op"] in target_ops or is_multiple_inputs:
            node_direct_ancestor.append(item_idx[0])
        else:
            tmp = get_direct_ancestor(node_list, visited_dict, target_ops,
                                      item_idx[0], input_names)
            for tmp_item in tmp:
                node_direct_ancestor.append(tmp_item)
    # Single-input nodes keep only their first ancestor.
    if not has_multiple_inputs(node_list, node_idx, input_names) and node_direct_ancestor:
        node_direct_ancestor = [node_direct_ancestor[0]]
    visited_dict[node_idx] = node_direct_ancestor
    return node_direct_ancestor
def get_in_nodes(node_list, target_ops, input_names):
    """Create a dictionary mapping from op_name nodes or multi_input
    nodes to closest input ancestors.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in a graph.

    target_ops: List of str
        List of target relay op

    input_names : list of str
        Names of graph input nodes.

    Returns
    -------
    out : dict of int to list of int
        Dictionary maps node index to closest input ancestors.
    """
    visited_dict = {}
    in_node_dict = {}
    # Populate visited_dict with the closest ancestors of every node.
    for i, node in enumerate(node_list):
        if node["op"] in RULE_OUT_NODE_NAMES:
            continue
        get_direct_ancestor(node_list, visited_dict, target_ops, i, input_names)
    # Keep only target-op nodes and multi-input nodes.
    for key, val in visited_dict.items():
        node = node_list[key]
        if node["op"] in target_ops or has_multiple_inputs(node_list, key, input_names):
            in_node_dict[key] = val

    # Iteratively drop nodes whose ancestor list became empty, detaching
    # them from their consumers until a fixed point is reached.
    out_node_dict = get_out_nodes(in_node_dict)
    while True:
        empty_nodes = [key for key, val in in_node_dict.items() if not val]
        if not empty_nodes:
            break
        for node in empty_nodes:
            del in_node_dict[node]
            if node in out_node_dict:
                for out_node in out_node_dict[node]:
                    in_node_dict[out_node].remove(node)

    return in_node_dict


def get_out_nodes(in_node_dict):
    """Create output dictionary from input dictionary.

    Parameters
    ----------
    in_node_dict : dict of int to list of int
        Dictionary maps node index to closest input ancestors.
        It can be created with get_in_nodes.

    Returns
    -------
    out : dict of int to list of int
        Dictionary maps node index to closest output nodes.
    """
    out_node_dict = {key: [] for key in in_node_dict}
    for key, ancestors in in_node_dict.items():
        for anc in ancestors:
            out_node_dict.setdefault(anc, []).append(key)
    return out_node_dict
def has_multiple_inputs(node_list, node_idx, input_names):
    """Check whether a node has multiple input nodes
    except variable nodes.

    Parameters
    ----------
    node_list : list of dict of str to object
        List of all nodes in a graph.

    node_idx : int
        Node index to be checked.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        Whether the specified node has multiple input nodes
    """
    node = node_list[node_idx]
    # Count inputs that are real ops or graph-level inputs; parameter
    # nodes (op == "null" and not a graph input) are excluded.
    num_inputs = sum(
        1 for entry in node["inputs"]
        if node_list[entry[0]]["op"] != "null"
        or is_input_node(node_list[entry[0]], input_names)
    )
    return num_inputs > 1


def is_input_node(node_entry, input_names):
    """Whether a node is an input node.

    Parameters
    ----------
    node_entry : dict
        Node entry.

    input_names : list of str
        List of input names of graph.

    Returns
    -------
    out : bool
        whether node is a input node.
    """
    if "name" not in node_entry:
        return False
    return node_entry["name"] in input_names


def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
    """Bind input variables of a relay function expression
    to new shapes and/or dtypes.

    Parameters
    ----------
    expr : tvm.relay.Expr.Function
        Input relay function expression.

    input_shapes : dict of str to tuple of int, optional
        Input shapes.

    input_dtypes : str or dict of str to str, optional
        Input dtypes.

    Returns
    -------
    out : tvm.relay.Expr.Function
        Bind relay function expression.
    """
    if input_shapes is None:
        return expr
    # A single dtype string applies uniformly to every input.
    if isinstance(input_dtypes, str):
        input_dtypes = {name: input_dtypes for name in input_shapes}

    updated_input_dict = {
        name: relay.var(name, shape=shape, dtype=input_dtypes[name])
        for name, shape in input_shapes.items()
    }

    rebind_dict = {
        param: updated_input_dict[param.name_hint]
        for param in expr.params
        if param.name_hint in updated_input_dict
    }
    updated_expr = relay.expr.bind(expr, rebind_dict)

    return relay.ir_pass.infer_type(updated_expr)
Parameters ---------- @@ -246,7 +250,12 @@ def pick_best(in_file, out_file): out_file: str or file The filename of output """ - best_context = ApplyHistoryBest(load_from_file(in_file)) + context = load_from_file(in_file) + if os.path.isfile(out_file): + out_context = load_from_file(out_file) + context = itertools.chain(context, out_context) + context, context_clone = itertools.tee(context) + best_context = ApplyHistoryBest(context) best_set = set() for v in best_context.best_by_model.values(): @@ -258,7 +267,7 @@ def pick_best(in_file, out_file): logger.info("Extract %d best records from the %s", len(best_set), in_file) fout = open(out_file, 'w') if isinstance(out_file, str) else out_file - for inp, res in load_from_file(in_file): + for inp, res in context_clone: if measure_str_key(inp) in best_set: fout.write(encode(inp, res) + "\n") best_set.remove(measure_str_key(inp)) diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py index ff50a4ebc81d..0a0e6e1e8ac7 100644 --- a/python/tvm/autotvm/task/__init__.py +++ b/python/tvm/autotvm/task/__init__.py @@ -28,6 +28,7 @@ from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \ FallbackContext, clear_fallback_cache, ApplyGraphBest -from .topi_integration import register_topi_compute, register_topi_schedule +from .topi_integration import register_topi_compute, register_topi_schedule, \ + TaskExtractEnv from .nnvm_integration import extract_from_graph, extract_from_multiple_graph from .relay_integration import extract_from_program, extract_from_multiple_program diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 3c983768ab3e..c48d4f58edce 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -74,7 +74,7 @@ class TaskExtractEnv: """Global environment for extracting tuning tasks from nnvm graph""" current = None - def __init__(self): + def __init__(self, 
allow_duplicate=False): import topi # topi compute -> autotvm task name @@ -106,6 +106,7 @@ def __init__(self): topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], } + self.allow_duplicate = allow_duplicate self._register_tracing() self._register_topi_task() self.task_collection = [] @@ -123,10 +124,9 @@ def _tracing_topi_compute(*args, **kwargs): assert not kwargs, "Do not support extracting tuning tasks when" \ "kwargs is used in TOPI function call." \ "Please modify it to use only positional args." - if compute_func in self.wanted_topi_funcs: # record this call key = (self.topi_to_task[compute_func], serialize_args(args)) - if key not in self.task_collection: + if self.allow_duplicate or key not in self.task_collection: self.task_collection.append(key) return compute_func.fdefault(*args) _local_scope(topi_compute) @@ -262,20 +262,29 @@ def get_tasks(self): return self.task_collection @staticmethod - def get(): + def get(allow_duplicate=False): """Get the single instance of TaskExtractEnv + Parameters + ---------- + allow_duplicate : boolean + Whether to fetch all workloads in the network, + even though some of them are the same. This is + useful for graph tuning. + Returns ------- env: TaskExtractEnv The single instance of TaskExtractEnv """ if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv() + TaskExtractEnv.current = TaskExtractEnv(allow_duplicate) + else: + TaskExtractEnv.current.allow_duplicate = allow_duplicate return TaskExtractEnv.current -def register_topi_compute(topi_compute, target_keys, template_keys, func=None): +def register_topi_compute(topi_compute, target_keys, template_keys, func=None, override=False): """Register a tunable template for a topi compute function. After the registration, this topi compute will become a configuration dispatcher. 
It uses @@ -324,7 +333,7 @@ def config_dispatcher(*args, **kwargs): config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute] - @config_dispatcher.register(template_keys) + @config_dispatcher.register(template_keys, override=override) def template_call(cfg, *args, **kwargs): """call the topi func and attach workload to compute node""" assert not kwargs, "Do not support kwargs in template function call" @@ -363,7 +372,7 @@ def template_call(cfg, *args, **kwargs): return _decorator -def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None): +def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None, override=False): """Register a tunable template for a topi schedule function. After the registration. This topi schedule will become a configuration dispatcher. It dispatches @@ -429,7 +438,7 @@ def traverse(tensors): config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule] - @config_dispatcher.register(template_keys) + @config_dispatcher.register(template_keys, override=override) def template_call(cfg, outs, *args, **kwargs): """call the schedule func""" if f == topi_schedule.fdefault: diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 120bf629a959..76170a844db1 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -143,7 +143,8 @@ class BuildConfig(NodeBase): "double_buffer_split_loop": 1, "dump_pass_ir": False, "instrument_bound_checkers": False, - "disable_select_rewriting": False + "disable_select_rewriting": False, + "disable_vectorize": False } _dump_ir = DumpIR() @@ -186,7 +187,7 @@ def __enter__(self): def __exit__(self, ptype, value, trace): if self.dump_pass_ir: BuildConfig._dump_ir.exit() - _api_internal._ExitBuildConfigScope() + _api_internal._ExitBuildConfigScope(self) def __setattr__(self, name, value): if name in BuildConfig._node_defaults: @@ -384,7 +385,10 @@ def lower(sch, # Phase 2 if not simple_mode: stmt = ir_pass.LoopPartition(stmt, 
cfg.partition_const_loop) - stmt = ir_pass.VectorizeLoop(stmt) + if cfg.disable_vectorize: + stmt = ir_pass.SkipVectorize(stmt) + else: + stmt = ir_pass.VectorizeLoop(stmt) stmt = ir_pass.InjectVirtualThread(stmt) stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) stmt = ir_pass.StorageRewrite(stmt) diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index c656fcc2b966..7c024b792867 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -17,10 +17,10 @@ """External function interface to BLAS libraries.""" from __future__ import absolute_import as _abs -from .. import api as _api -from .. import intrin as _intrin +from .. import api as _api, intrin as _intrin -def matmul(lhs, rhs, transa=False, transb=False): + +def matmul(lhs, rhs, transa=False, transb=False, **kwargs): """Create an extern op that compute matrix mult of A and rhs with CrhsLAS This function serves as an example on how to call external libraries. @@ -44,7 +44,50 @@ def matmul(lhs, rhs, transa=False, transb=False): n = lhs.shape[1] if transa else lhs.shape[0] m = rhs.shape[0] if transb else rhs.shape[1] return _api.extern( - (n, m), [lhs, rhs], + (n, m), + [lhs, rhs], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + **kwargs + ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs): + """Create an extern op that compute batched matrix mult of A and rhs with CBLAS + This function serves as an example on how to call external libraries. + Parameters + ---------- + lhs : Tensor + The left matrix operand + rhs : Tensor + The right matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + Returns + ------- + C : Tensor + The result tensor. 
+ """ + b = lhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + return _api.extern( + (b, n, m), + [lhs, rhs], lambda ins, outs: _intrin.call_packed( - "tvm.contrib.cblas.matmul", - ins[0], ins[1], outs[0], transa, transb), name="C") + "tvm.contrib.cblas.batch_matmul" + if not iterative + else "tvm.contrib.cblas.batch_matmul_iterative", + ins[0], + ins[1], + outs[0], + transa, + transb, + ), + name="C", + **kwargs + ) diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index c53a2c287339..882364dd3971 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -207,10 +207,8 @@ def dump_graph_json(self, graph): def display_debug_result(self): """Displays the debugger result" """ - header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Start Time", \ - "End Time", "Shape", "Inputs", "Outputs"] - lines = ["---------", "---", "--------", "-------", "----------", \ - "--------", "-----", "------", "-------"] + header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Shape", "Inputs", "Outputs"] + lines = ["---------", "---", "--------", "-------", "-----", "------", "-------"] eid = 0 data = [] total_time = sum(time[0] for time in self._time_list) @@ -223,12 +221,11 @@ def display_debug_result(self): continue name = node['name'] shape = str(self._output_tensor_list[eid].shape) - time_us = round(time[0] * 1000000, 2) - time_percent = round(((time[0] / total_time) * 100), 2) + time_us = round(time[0] * 1000000, 3) + time_percent = round(((time[0] / total_time) * 100), 3) inputs = str(node['attrs']['num_inputs']) outputs = str(node['attrs']['num_outputs']) - node_data = [name, op, time_us, time_percent, str(time[1]), str(time[2]), \ - shape, inputs, outputs] + node_data = [name, op, time_us, time_percent, shape, inputs, outputs] data.append(node_data) eid += 1 fmt = "" diff --git 
a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 01cda35769a5..f77a927eeabf 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -19,7 +19,6 @@ import os import tempfile import shutil -from datetime import datetime from tvm._ffi.base import string_types from tvm._ffi.function import get_global_func from tvm.contrib import graph_runtime @@ -30,6 +29,7 @@ _DUMP_ROOT_PREFIX = "tvmdbg_" _DUMP_PATH_PREFIX = "_tvmdbg_" + def create(graph_json_str, libmod, ctx, dump_root=None): """Create a runtime executor module given a graph and module. @@ -62,17 +62,23 @@ def create(graph_json_str, libmod, ctx, dump_root=None): try: fcreate = get_global_func("tvm.graph_runtime_debug.create") except ValueError: - raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \ - "config.cmake and rebuild TVM to enable debug mode") + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " + "config.cmake and rebuild TVM to enable debug mode" + ) ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): libmod = rpc_base._ModuleHandle(libmod) try: - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_debug.remote_create") + fcreate = ctx[0]._rpc_sess.get_function( + "tvm.graph_runtime_debug.remote_create" + ) except ValueError: - raise ValueError("Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " \ - "config.cmake and rebuild TVM to enable debug mode") + raise ValueError( + "Please set '(USE_GRAPH_RUNTIME_DEBUG ON)' in " + "config.cmake and rebuild TVM to enable debug mode" + ) func_obj = fcreate(graph_json_str, libmod, *device_type_id) return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root) @@ -100,10 +106,10 @@ class GraphModuleDebug(graph_runtime.GraphModule): To select which folder the outputs should be kept. 
None will make a temp folder in /tmp/tvmdbg and does the dumping """ + def __init__(self, module, ctx, graph_json_str, dump_root): self._dump_root = dump_root self._dump_path = None - self._debug_run = module["debug_run"] self._get_output_by_layer = module["get_output_by_layer"] self._run_individual = module["run_individual"] graph_runtime.GraphModule.__init__(self, module) @@ -181,13 +187,10 @@ def _run_debug(self): Time consumed for each execution will be set as debug output. """ - self.debug_datum._time_list = [] - + self.debug_datum._time_list = [ + [float(t) * 1e-6] for t in self.run_individual(10, 1, 1) + ] for i, node in enumerate(self.debug_datum.get_graph_nodes()): - start_time = datetime.now().time() - time_stamp = self._debug_run(i) - end_time = datetime.now().time() - self.debug_datum._time_list.append([time_stamp, start_time, end_time]) num_outputs = self.debug_datum.get_graph_node_output_num(node) for j in range(num_outputs): out_tensor = self._get_output_by_layer(i, j) @@ -212,8 +215,13 @@ def debug_get_output(self, node, out): ret = output_tensors[node] except: node_list = output_tensors.keys() - raise RuntimeError("Node " + node + " not found, available nodes are: " - + str(node_list) + ".") + raise RuntimeError( + "Node " + + node + + " not found, available nodes are: " + + str(node_list) + + "." 
+ ) elif isinstance(node, int): output_tensors = self.debug_datum._output_tensor_list ret = output_tensors[node] @@ -242,7 +250,9 @@ def run(self, **input_dict): self.debug_datum.display_debug_result() def run_individual(self, number, repeat=1, min_repeat_ms=0): - self._run_individual(number, repeat, min_repeat_ms) + ret = self._run_individual(number, repeat, min_repeat_ms) + return ret.strip(",").split(",") if ret else [] + def exit(self): """Exits the dump folder and all its contents""" diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 4d0698a40db7..0c9ce404c48e 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -129,6 +129,7 @@ def __init__(self, module): self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] self._load_params = module["load_params"] + self._share_params = module["share_params"] def set_input(self, key=None, value=None, **params): """Set inputs to the module via kwargs @@ -234,6 +235,19 @@ def load_params(self, params_bytes): """ self._load_params(bytearray(params_bytes)) + def share_params(self, other, params_bytes): + """Share parameters from pre-existing GraphRuntime instance. + + Parameters + ---------- + other: GraphRuntime + The parent GraphRuntime from which this instance should share + it's parameters. + params_bytes : bytearray + The serialized parameter dict (used only for the parameter names). + """ + self._share_params(other.module, bytearray(params_bytes)) + def __getitem__(self, key): """Get internal module function diff --git a/python/tvm/datatype.py b/python/tvm/datatype.py new file mode 100644 index 000000000000..df3e3a62a510 --- /dev/null +++ b/python/tvm/datatype.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Custom datatype functionality""" +from __future__ import absolute_import as _abs + +from ._ffi.function import register_func as _register_func +from . import make as _make +from .api import convert +from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from ._ffi.runtime_ctypes import TVMType as _TVMType +from . import _api_internal + + +def register(type_name, type_code): + """Register a custom datatype with the given type name and type code + Currently, the type code is manually allocated by the user, and the + user must ensure that no two custom types share the same code. + Generally, this should be straightforward, as the user will be + manually registering all of their custom types. 
+ + Parameters + ---------- + type_name : str + The name of the custom datatype + + type_code : int + The type's code, which should be >= kCustomBegin + """ + _api_internal._datatype_register(type_name, type_code) + + +def get_type_name(type_code): + """Get the type name from the type code + + Parameters + ---------- + type_code : int + The type code + """ + return _api_internal._datatype_get_type_name(type_code) + + +def get_type_code(type_name): + """Get the type code from the type name + + Parameters + ---------- + type_name : str + The type name + """ + return _api_internal._datatype_get_type_code(type_name) + + +def get_type_registered(type_code): + """Get a boolean representing whether the type is registered + + Parameters + ---------- + type_code: int + The type code + """ + return _api_internal._datatype_get_type_registered(type_code) + + +def register_op(lower_func, op_name, target, type_name, src_type_name=None): + """Register an external function which computes the given op. + + Currently, this will only work with Casts and binary expressions + whose arguments are named `a` and `b`. + TODO(gus) figure out what other special cases must be handled by + looking through expr.py. + + Parameters + ---------- + lower_func : function + The lowering function to call. See create_lower_func. + + op_name : str + The name of the operation which the function computes, given by its + Halide::Internal class name (e.g. Add, LE, Cast). + + target : str + The name of codegen target. + + type_name : str + The name of the custom datatype, e.g. posit (but not custom[posit]8). + + src_type_name : str + If op_name is "Cast", then this should be set to the source datatype of + the argument to the Cast. If op_name is not "Cast", this is unused. + """ + + if op_name == "Cast": + assert src_type_name is not None + lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + + type_name + "." + src_type_name + else: + lower_func_name = "tvm.datatype.lower." 
+ target + "." + op_name + "." \ + + type_name + _register_func(lower_func_name, lower_func) + + +def create_lower_func(extern_func_name): + """Returns a function which lowers an operation to a function call. + + Parameters + ---------- + extern_func_name : str + The name of the extern "C" function to lower to + """ + + def lower(op): + """ + Takes an op---either a Cast or a binary op (e.g. an Add) and returns a + call to the specified external function, passing the op's argument + (Cast) or arguments (a binary op). The return type of the call depends + on the type of the op: if it is a custom type, then a uint of the same + width as the custom type is returned. Otherwise, the type is + unchanged.""" + dtype = op.dtype + t = _TVMType(dtype) + if get_type_registered(t.type_code): + dtype = "uint" + str(t.bits) + if t.lanes > 1: + dtype += "x" + str(t.lanes) + if isinstance(op, (_Cast, _FloatImm)): + return _make.Call(dtype, extern_func_name, convert([op.value]), + _Call.Extern, None, 0) + return _make.Call(dtype, extern_func_name, convert([op.a, op.b]), + _Call.Extern, None, 0) + + return lower diff --git a/python/tvm/expr.py b/python/tvm/expr.py index a234ac4da53b..9c8a9ab89d3b 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -222,7 +222,7 @@ def asnode(self): class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" - # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ + # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = NodeBase.__hash__ @@ -349,6 +349,16 @@ def __init__(self, value): self.__init_handle_by_constructor__( _make.StringImm, value) + def __eq__(self, other): + if isinstance(other, ConstExpr): + return self.value == other.value + return self.value == other + + def __ne__(self, other): + if isinstance(other, ConstExpr): + return self.value != other.value + return 
self.value != other + @register_node class Cast(Expr): diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 6201681e0294..5536e503e6b6 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -25,7 +25,9 @@ from . import module from . import adt from . import ir_pass -from .build_module import build, build_config, create_executor, optimize +from . import transform +from .build_module import build, create_executor +from .transform import build_config from . import prelude from . import parser from . import debug @@ -97,9 +99,9 @@ var = expr.var const = expr.const bind = expr.bind -module_pass = ir_pass.module_pass -function_pass = ir_pass.function_pass -sequential_pass = ir_pass.sequential_pass +module_pass = transform.module_pass +function_pass = transform.function_pass +alpha_equal = ir_pass.alpha_equal # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -114,9 +116,9 @@ load_param_dict = param_dict.load_param_dict # Pass manager -PassInfo = ir_pass.PassInfo -PassContext = ir_pass.PassContext -Pass = ir_pass.Pass -ModulePass = ir_pass.ModulePass -FunctionPass = ir_pass.FunctionPass -SequentialPass = ir_pass.SequentialPass +PassInfo = transform.PassInfo +PassContext = transform.PassContext +Pass = transform.Pass +ModulePass = transform.ModulePass +FunctionPass = transform.FunctionPass +Sequential = transform.Sequential diff --git a/python/tvm/relay/_build_module.py b/python/tvm/relay/_build_module.py new file mode 100644 index 000000000000..bdbcbefff523 --- /dev/null +++ b/python/tvm/relay/_build_module.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable +"""The interface for building Relay functions exposed from C++.""" +from tvm._ffi.function import _init_api + +_init_api("relay.build_module", __name__) diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index 6aedb5248657..13035bb36f71 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -17,62 +17,8 @@ import tvm from . import ir -from .base import NodeBase from .env import Module - -class PassContext(NodeBase): - def __init__(self): - ... - -class PassInfo(NodeBase): - name = ... # type: str - opt_level = ... # type: int - required = ... # type: list - - def __init__(self, name, opt_level, required) - # type: (str, int, list) -> None - - -class Pass(NodeBase): - def __init__(self): - ... - - -class ModulePass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class FunctionPass(Pass): - name = ... # type: str - opt_level = ... # type: int - pass_func = ... # type: Callable - required = ... # type: list - - def __init__(self, name, opt_level, pass_func, required): - # type: (str, int, Callable, list) -> None - ... - - -class SequentialPass(Pass): - name = ... # type: str - opt_level = ... # type: int - passes = ... # type: list - required = ... # type: list - disabled = ... 
# type: list - - def __init__(self, name, opt_level, passes, required, disabled): - # type: (str, int, list, list, list) -> None - ... - - def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 62f0ffe15cba..303f694896a5 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -242,10 +242,12 @@ def visitProg(self, ctx): self.visit_list(ctx.defn()) return self.module - return self.visit(ctx.expr()) + if ctx.expr(): + return self.visit(ctx.expr()) - # Exprs + return self.module + # Exprs def visitOpIdent(self, ctx): # type: (RelayParser.OpIdentContext) -> op.Op return op.get(ctx.CNAME().getText()) @@ -368,14 +370,25 @@ def mk_func(self, ctx): self.enter_var_scope() # Capture type params in params. self.enter_type_param_scope() + type_params = ctx.typeParamSeq() + + if type_params is not None: + type_params = type_params.ident() + assert type_params + for ty_param in type_params: + name = ty_param.getText() + self.mk_typ(name, ty.Kind.Type) + var_list, attr_list = self.visit(ctx.argList()) ret_type = self.getType_(ctx.type_()) + body = self.visit(ctx.body()) + # NB(@jroesch): you must stay in the type parameter scope until + # after you exit the body, you can reference the type parameters + # of your parent scopes. 
type_params = list(self.exit_type_param_scope()) if type_params: _, type_params = zip(*type_params) - - body = self.visit(ctx.body()) self.exit_var_scope() attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None @@ -453,16 +466,23 @@ def visitIncompleteType(self, ctx): # type (RelayParser.IncompleteTypeContext) -> None: return None - def visitIdentType(self, ctx): - # type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str] - ident_type = ctx.CNAME().getText() + def visitTypeIdent(self, ctx): + # type: (RelayParser.TypeIdentContext) -> Union[ty.TensorType, str] + ''' + Handle type identifier. + ''' + type_ident = ctx.CNAME().getText() - # look through all type prefixes for a match + # Look through all type prefixes for a match for type_prefix in TYPE_PREFIXES: - if ident_type.startswith(type_prefix): - return ty.scalar_type(ident_type) + if type_ident.startswith(type_prefix): + return ty.scalar_type(type_ident) + + type_param = lookup(self.type_param_scopes, type_ident) + if type_param is not None: + return type_param - raise ParseError("Unknown builtin type: {}".format(ident_type)) + raise ParseError("Unknown builtin type: {}".format(type_ident)) # def visitCallType(self, ctx): # # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType] diff --git a/python/tvm/relay/_transform.py b/python/tvm/relay/_transform.py new file mode 100644 index 000000000000..273d97e0962a --- /dev/null +++ b/python/tvm/relay/_transform.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI exposing the Relay type inference and checking.""" + +from tvm._ffi.function import _init_api + +_init_api("relay._transform", __name__) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 50e9694b40df..860788a4e5d0 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,7 +17,6 @@ """The interface of expr function exposed from C++.""" from __future__ import absolute_import -import logging from ... import build_module as _build from ... import container as _container from ..._ffi.function import _init_api, register_func @@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func): # pylint: disable=broad-except try: f = _build.lower(sch, inputs, name=func_name) - logging.debug("lower function %s", func_name) - logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) + # logging.debug("lower function %s", func_name) + # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) except Exception: msg = traceback.format_exc() msg += "Error during compile function\n" diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index ea1846b93beb..cf31e9cff833 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -36,12 +36,9 @@ from __future__ import absolute_import from tvm.ndarray import empty -from tvm._ffi.function import _init_api - from tvm.relay import build_module from tvm import target as _target - 
-_init_api("tvm.relay.build_module") +from tvm import expr as _expr class GraphRuntimeCodegen(object): """The compiler from Relay to the TVM runtime system.""" @@ -57,17 +54,14 @@ def __init__(self, mod, target): self._setup(mod, target) def _setup(self, mod, target): - tgts = [] + tgts = {} if isinstance(target, dict): - for kv in target.items(): - tgts.append(kv[0]) - if isinstance(kv[1], (str, _target.Target)): - tgts.append(str(kv[1])) - else: + for dev, tgt in target.items(): + if not isinstance(tgt, (str, _target.Target)): raise Exception("Unknown target type") + tgts[dev] = _target.create(tgt) elif isinstance(target, (str, _target.Target)): - tgts.append("0") - tgts.append(str(target)) + tgts[_expr.IntImm("int32", 0)] = _target.create(target) self._init(mod, tgts) def codegen(self, func): diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index fc47f4e1b7c8..c54a65b78fb2 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -21,7 +21,8 @@ import numpy as np from . import _backend -from .. import _make, ir_pass +from .. import _make, ir_pass, transform +from .. import module from ... 
import register_func, nd from ..base import NodeBase, register_relay_node from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const @@ -73,9 +74,9 @@ class Closure(Value): @register_relay_node class ConstructorValue(Value): - def __init__(self, constructor, fields, types): + def __init__(self, tag, fields, constructor, types): self.__init_handle_by_constructor__( - _make.ConstructorValue, constructor, fields, types) + _make.ConstructorValue, tag, fields, constructor, types) @register_relay_node @@ -118,6 +119,8 @@ def _arg_to_ast(arg): return Constant(arg.data.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): return Tuple([_arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, tuple): + return Tuple([_arg_to_ast(field) for field in arg]) elif isinstance(arg, RefValue): return RefCreate(_arg_to_ast(arg.value)) elif isinstance(arg, ConstructorValue): @@ -189,14 +192,14 @@ def _convert_args(self, expr, args, kwargs): return tuple(cargs) - def _make_executor(self, _): + def _make_executor(self, expr=None): """ Construct a Python function that implements the evaluation of expression. Parameters ---------- - expr: relay.Expr + expr: Optional[relay.Expr] The Relay expression to execute. Returns @@ -206,16 +209,16 @@ def _make_executor(self, _): """ raise NotImplementedError() - def evaluate(self, expr, binds=None): + def evaluate(self, expr=None, binds=None): """ Evaluate a Relay expression on the executor. Parameters ---------- - expr: tvm.relay.Expr + expr: Optional[tvm.relay.Expr] The expression to evaluate. - binds: Map[tvm.relay.Var, tvm.relay.Expr] + binds: Optional[Map[tvm.relay.Var, tvm.relay.Expr]] Additional binding of free variable. 
Returns @@ -230,6 +233,9 @@ def evaluate(self, expr, binds=None): scope_builder.ret(expr) expr = scope_builder.get() + if not expr: + return self._make_executor() + if isinstance(expr, Function): assert not ir_pass.free_vars(expr) @@ -262,46 +268,47 @@ def __init__(self, mod, ctx, target): self.target = target self._intrp = _backend.CreateInterpreter(mod, ctx, target) - def optimize(self, expr): - """Optimize an expr. - - Parameters - ---------- - expr : Expr - The expression to be optimized. + def optimize(self): + """Optimize functions in a module. Returns ------- - opt_expr : Expr - The optimized expression. + opt_mod : tvm.relay.Module + The optimized module. """ - # TODO: We need to move this optimization code into the optimizer/pass manager - wrapped_expr = expr if isinstance(expr, Function) else Function([], expr) - if self.mod: - self.mod[self.mod.entry_func] = wrapped_expr - ck_expr = ir_pass.infer_type(wrapped_expr, mod=self.mod) - simp_expr = ir_pass.simplify_inference(ck_expr) - ck_simp = ir_pass.infer_type(simp_expr, mod=self.mod) - fused_expr = ir_pass.fuse_ops(ck_simp, 0, mod=self.mod) - ck_fused = ir_pass.infer_type(fused_expr, mod=self.mod) - return ck_fused if isinstance(expr, Function) else Call(ck_fused, []) - - def _make_executor(self, expr): + seq = transform.Sequential([transform.SimplifyInference(), + transform.FuseOps(0), + transform.InferType()]) + return seq(self.mod) + + def _make_executor(self, expr=None): + if expr is None or isinstance(expr, GlobalVar): + assert self.mod is not None def _interp_wrapper(*args, **kwargs): - args = self._convert_args(expr, args, kwargs) + if expr is None: + args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) + else: + args = self._convert_args(expr, args, kwargs) relay_args = [] for arg in args: relay_args.append(_arg_to_ast(arg)) - if isinstance(expr, GlobalVar): - func = self.mod[expr] - func = self.optimize(func) - self.mod._add(expr, func, True) - opt_expr = Call(expr, relay_args) 
- return self._intrp(opt_expr) + # Set the entry function for the module. + if expr is None: + pass + elif isinstance(expr, GlobalVar): + self.mod[self.mod.entry_func] = self.mod[expr] else: - call = Call(expr, relay_args) - opt_expr = self.optimize(call) - return self._intrp(opt_expr) + assert isinstance(expr, Function) + func = Function([], Call(expr, relay_args)) + relay_args = [] + if self.mod: + self.mod[self.mod.entry_func] = func + else: + self.mod = module.Module.from_expr(func) + + mod = self.optimize() + opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args) + return self._intrp(opt_expr) return _interp_wrapper diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index bebadd167fe9..ceb403fe7717 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -20,24 +20,45 @@ Implements a Python interface to compiling and executing on the Relay VM. """ +import numpy as np + import tvm from tvm._ffi.function import Object -import numpy as np -from .. import ir_pass +from .. import transform from ..backend.interpreter import Executor -from ..expr import GlobalVar, Function, Expr +from ..expr import GlobalVar, Expr from . import _vm Object = Object -def optimize(expr, mod=None): - # TODO: We need to move this optimization code into the optimizer/pass manager - ck_expr = ir_pass.infer_type(expr, mod=mod) - simplified_expr = ir_pass.simplify_inference(ck_expr) - simplified_expr = ir_pass.infer_type(simplified_expr, mod=mod) - fused_expr = ir_pass.fuse_ops(simplified_expr, mod=mod) - ck_fused = ir_pass.infer_type(fused_expr, mod=mod) - return ck_fused +def optimize(mod): + """Perform several optimizations on a module before executing it in the + Relay virtual machine. + + Parameters + ---------- + mod : tvm.relay.Module + The module to optimize. + + Returns + ------- + ret : tvm.relay.Module + The optimized module. 
+ """ + main_func = mod[mod.entry_func] + + opt_passes = [] + if not main_func.params and isinstance(main_func.body, GlobalVar): + opt_passes.append(transform.EtaExpand()) + + opt_passes = opt_passes + [ + transform.SimplifyInference(), + transform.FuseOps(), + transform.InferType() + ] + + seq = transform.Sequential(opt_passes) + return seq(mod) def _convert(arg, cargs): if isinstance(arg, np.ndarray): @@ -76,15 +97,7 @@ def _eval_vm(mod, ctx, *args): args: List[tvm.NDArray, np.ndarray] The arguments to evaluate. """ - main_func = mod[mod.entry_func] - - if not main_func.params and isinstance(main_func.body, GlobalVar): - main_func = ir_pass.eta_expand(main_func.body, mod) - - assert isinstance(main_func, Function) - main_func = optimize(mod[mod.entry_func], mod) - mod[mod.entry_func] = main_func - + mod = optimize(mod) args = list(args) assert isinstance(args, list) cargs = convert(args) @@ -117,9 +130,11 @@ def __init__(self, mod, ctx, target): self.ctx = ctx self.target = target - def _make_executor(self, expr): - assert isinstance(expr, Expr) - self.mod[self.mod.entry_func] = expr + def _make_executor(self, expr=None): + expr = expr if expr else self.mod + assert expr, "either expr or self.mod should be not null." + if isinstance(expr, Expr): + self.mod[self.mod.entry_func] = expr main = self.mod[self.mod.entry_func] def _vm_wrapper(*args, **kwargs): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c8b69e011543..1aa4d5ae57c4 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -18,219 +18,128 @@ Construct the necessary state for the TVM graph runtime from a Relay expression. """ -import warnings +import numpy as np -from tvm._ffi.runtime_ctypes import TVMContext -from ..build_module import build as _tvm_build_module +from tvm import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt +from . import _build_module from . 
import ir_pass -from . import expr as _expr from . import ty as _ty +from . import expr as _expr from .backend import interpreter as _interpreter -from .backend import graph_runtime_codegen as _graph_gen from .backend.vm import VMExecutor -# List of optimization pass and level when switch on -OPT_PASS_LEVEL = { - "SimplifyInference": 0, - "OpFusion": 1, - "FoldConstant": 2, - "CombineParallelConv2D": 3, - "FoldScaleAxis": 3, - "AlterOpLayout": 3, - "CanonicalizeOps": 3, - "EliminateCommonSubexpr": 3, -} +def _update_target(target): + target = target if target else _target.current_target() + if target is None: + raise ValueError("Target is not set in env or passed as argument.") + tgts = {} + if isinstance(target, (str, _target.Target)): + dev_type = tvm_expr.IntImm("int32", _nd.context(str(target)).device_type) + tgts[dev_type] = _target.create(target) + elif isinstance(target, dict): + for dev, tgt in target.items(): + dev_type = tvm_expr.IntImm("int32", _nd.context(dev).device_type) + tgts[dev_type] = _target.create(tgt) + else: + raise TypeError("target is expected to be str or " + + "tvm.target.Target, but received " + + "{}".format(type(target))) + return tgts -class BuildConfig(object): - """Configuration scope to set a build config option. - Parameters - ---------- - kwargs - Keyword arguments of configurations to set. +class BuildModule(object): + """Build a Relay function to run on TVM graph runtime. This class is used + to expose the `RelayBuildModule` APIs implemented in C++. 
""" - current = None - defaults = { - "opt_level": 2, - "add_pass": None, - "fallback_device": None, - } - - def __init__(self, **kwargs): - self._old_scope = None - for k, _ in kwargs.items(): - if k not in BuildConfig.defaults: - raise ValueError("invalid argument %s, candidates are %s" % - (k, BuildConfig.defaults.keys())) - self._attr = kwargs - - def __getattr__(self, name): - if name not in self._attr: - return BuildConfig.defaults[name] - return self._attr[name] - - def __enter__(self): - # pylint: disable=protected-access - self._old_scope = BuildConfig.current - attr = BuildConfig.current._attr.copy() - attr.update(self._attr) - self._attr = attr - BuildConfig.current = self - return self - - def __exit__(self, ptype, value, trace): - assert self._old_scope - BuildConfig.current = self._old_scope - - def pass_enabled(self, pass_name): - """Get whether pass is enabled. - + def __init__(self): + self.mod = _build_module._BuildModule() + self._get_graph_json = self.mod["get_graph_json"] + self._get_module = self.mod["get_module"] + self._build = self.mod["build"] + self._set_params_func = self.mod["set_params"] + self._get_params_func = self.mod["get_params"] + + def build(self, func, target=None, target_host=None, params=None): + """ Parameters ---------- - pass_name : str - The optimization pass name + func: relay.Function + The function to build. + + target : str, :any:`tvm.target.Target`, or dict of str(i.e. + device/context name) to str/tvm.target.Target, optional + For heterogeneous compilation, it is a dictionary indicating context + to target mapping. For homogeneous compilation, it is a build target. + + target_host : str or :any:`tvm.target.Target`, optional + Host compilation target, if target is device. + When TVM compiles device specific program such as CUDA, + we also need host(CPU) side code to interact with the driver + to setup the dimensions and parameters correctly. + target_host is used to specify the host side codegen target. 
+ By default, llvm is used if it is enabled, + otherwise a stackvm intepreter is used. + + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. Returns ------- - enabled : bool - Whether pass is enabled. - """ - if self.add_pass and pass_name in self.add_pass: - return True - return self.opt_level >= OPT_PASS_LEVEL[pass_name] - - -BuildConfig.current = BuildConfig() - - -def build_config(**kwargs): - """Configure the build behavior by setting config variables. - - Parameters - ---------- - opt_level: int, default=2 - Optimization level. See OPT_PASS_LEVEL for level of each pass. - - add_pass: set of str - Optimization pass to be added regardless of optimization level. - - fallback_device : str or tvm.TVMContext - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - - Returns - ------- - config: BuildConfig - The build configuration - """ - return BuildConfig(**kwargs) - - -def _bind_params_by_name(func, params): - """Bind parameters of function by its name.""" - name_dict = {} - for arg in func.params: - name = arg.name_hint - if name in name_dict: - name_dict[name] = None - else: - name_dict[name] = arg - bind_dict = {} - for k, v in params.items(): - if k not in name_dict: - continue - arg = name_dict[k] - if arg is None: - raise ValueError("Multiple args in the function have name %s" % k) - bind_dict[arg] = _expr.const(v) - return _expr.bind(func, bind_dict) - - -def optimize(func, target=None, params=None): - """Perform target invariant optimizations. + graph_json : str + The json string that can be accepted by graph runtime. - Parameters - ---------- - func : tvm.relay.Function - The input to optimization. + mod : tvm.Module + The module containing necessary libraries. - target : Optional[:any:`tvm.target.Target`, Dict[int, tvm.target.Target]] - The optimization target. 
For heterogeneous compilation, it is a - dictionary mapping device type to compilation target. For homogeneous - compilation, it is a build target. - - params : Optional[Dict[str, tvm.nd.NDArray]] - Input parameters to the graph that do not change - during inference time. used for constant folding. + params : dict + The parameters of the final graph. + """ + target = _update_target(target) - Returns - ------- - opt_func : tvm.relay.Function - The optimized version of the function. - """ - cfg = BuildConfig.current - - # bind expressions - if params: - func = _bind_params_by_name(func, params) - - if cfg.pass_enabled("SimplifyInference"): - func = ir_pass.infer_type(func) - func = ir_pass.simplify_inference(func) - - if cfg.pass_enabled("EliminateCommonSubexpr"): - def fskip(expr): - if isinstance(expr, _expr.Call) and expr.op.name == 'cast' and \ - expr.attrs.dtype == 'int32': - return True - return False - - func = ir_pass.infer_type(func) - func = ir_pass.eliminate_common_subexpr(func, fskip) - - if cfg.pass_enabled("CombineParallelConv2D"): - func = ir_pass.infer_type(func) - func = ir_pass.combine_parallel_conv2d(func) - - # The constant folding pass is necessary because FoldScaleAxis pass needs - # to check the constantness and positiveness of scales. - if cfg.pass_enabled("FoldConstant"): - func = ir_pass.fold_constant(func) - - if cfg.pass_enabled("FoldScaleAxis"): - func = ir_pass.infer_type(func) - func = ir_pass.backward_fold_scale_axis(func) - func = ir_pass.infer_type(func) - func = ir_pass.forward_fold_scale_axis(func) - func = ir_pass.fold_constant(func) - - if cfg.pass_enabled("CanonicalizeOps"): - func = ir_pass.infer_type(func) - func = ir_pass.canonicalize_ops(func) - - # FIXME(zhiics) Skip AlterOpLayout pass for heterogeneous compilation for - # now. We probably need to pass target to this pass as well. Fix it in - # a followup PR. 
- if cfg.pass_enabled("AlterOpLayout"): - if isinstance(target, _target.Target): - func = ir_pass.infer_type(func) - with target: - func = ir_pass.alter_op_layout(func) - elif isinstance(target, dict): - warnings.warn("AlterOpLayout pass is not enabled for heterogeneous" - " execution yet.") - - if cfg.pass_enabled("FoldConstant"): - func = ir_pass.fold_constant(func) - - return func + # Setup the params. + if params: + self._set_params(params) + # Build the function + self._build(func, target, target_host) + # Get artifacts + graph_json = self.get_json() + mod = self.get_module() + params = self.get_params() + + return graph_json, mod, params + + def _set_params(self, params): + inputs = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = _nd.array(param) + inputs[name] = _expr.const(param) + self._set_params_func(inputs) + + def get_json(self): + """Return the json file of the built program.""" + return self._get_graph_json() + + def get_module(self): + """Return the built module.""" + return self._get_module() + + def get_params(self): + """Return the updated weights.""" + params = self._get_params_func() + ret = {} + for key, value in params.items(): + ret[key] = value.data + return ret def build(func, target=None, target_host=None, params=None): - """Build a function to run on TVM graph runtime. + """Helper function that builds a Relay function to run on TVM graph + runtime. Parameters ---------- @@ -266,146 +175,28 @@ def build(func, target=None, target_host=None, params=None): params : dict The parameters of the final graph. 
""" - target = target if target else _target.current_target() - if target is None: - raise ValueError("Target is not set in env or passed as argument.") + target = _update_target(target) - if isinstance(target, dict): - target, fallback_device = _update_heterogeneous_inputs(target) - elif isinstance(target, (str, _target.Target)): - target = _target.create(target) - else: - raise ValueError("target must be the type of str, tvm.target.Target," + - "or dict of device name to target") + if isinstance(target_host, (str, _target.Target)): + target_host = _target.create(target_host) + elif target_host: + raise ValueError("target host must be the type of str, " + + "tvm.target.Target, or None") # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): - if isinstance(target, dict): - tophub_context = autotvm.tophub.context(list(target.values())) - else: - tophub_context = autotvm.tophub.context(target) + tophub_context = autotvm.tophub.context(list(target.values())) else: tophub_context = autotvm.util.EmptyContext() - cfg = BuildConfig.current - with tophub_context: - func = optimize(func, target, params) - # Annotate the ops for heterogeneous execution. 
- if isinstance(target, dict): - func, target = _run_device_annotation_passes(func, target, - fallback_device) - # Fuse ops before running code gen - func = ir_pass.infer_type(func) - func = ir_pass.fuse_ops(func, cfg.opt_level) - # Graph code generation - func = ir_pass.infer_type(func) - graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) - graph_json, lowered_funcs, params = graph_gen.codegen(func) - mod = _tvm_build_module( - lowered_funcs, target=target, target_host=target_host) + bld_mod = BuildModule() + graph_json, mod, params = bld_mod.build(func, target, target_host, + params) return graph_json, mod, params -def _update_heterogeneous_inputs(target): - """Update the target and fallback device required for heterogeneous - compilation. CPU is used as the fallback device if it wasn't provided. - Meanwhile, a CPU device type and "llvm" pair will be added to the target - dictionary in this case. - - Parameters - ---------- - target : dict of str(i.e. device/context name) to str/tvm.target.Target. - A dict contains context to target pairs. - - Returns - ------- - device_target : dict of int to tvm.target.Target. - The updated device type to target dict. - - fallback_device : int - The updated fallback device type. - """ - if not isinstance(target, dict): - raise ValueError("target must be dict of device name to target for " + - "heterogeneous execution, but received %s." - % type(target)) - - fallback_device = BuildConfig.current.fallback_device - if fallback_device is None: - # cpu is used as the default fallback device when heterogeneous - # execution is needed, but no fallback device is provided. 
- fallback_device = _nd.cpu(0).device_type - target[fallback_device] = str(_target.create("llvm")) - elif isinstance(fallback_device, str): - fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, TVMContext): - fallback_device = fallback_device.device_type - else: - raise ValueError("fallback_device expects the type of str or " + - "TVMContext, but received %s." % type(fallback_device)) - - device_target = {} - for dev, tgt in target.items(): - device_target[_nd.context(dev).device_type] = _target.create(tgt) - - if fallback_device not in device_target: - raise ValueError("%s is used as the default device, but the target" + - "is not provided." - % _nd.context(fallback_device).device_name) - return device_target, fallback_device - - -def _run_device_annotation_passes(func, target, fallback_device): - """Execute the device annotation passes to update the input program and - target information. - - Parameters - ---------- - func: tvm.relay.Function - The function where annotation passes will be execute at. - - target : Dict[int, tvm.target.Target] - A dict contains device type to target pairs. - - fallback_device : int - The fallback device type. - - Returns - ------- - target : Dict[int, tvm.target.Target] - The updated device type to target dict. - - func : tvm.relay.Function - The updated func. - """ - func = ir_pass.infer_type(func) - func = ir_pass.rewrite_annotated_ops(func, fallback_device) - device_map = ir_pass.collect_device_info(func) - # The expression to device type map will be empty if all or none of - # the expressions in the `func` are annotated because this map is - # obtained by propagating the device information in the device copy - # operator. None of the above cases needs device copy operator. - if not device_map: - annotation_map = ir_pass.collect_device_annotation_ops(func) - # No annotation. 
- if not annotation_map: - target = {0: target[fallback_device]} - else: - dev_type = next(iter(annotation_map.values())) - # All annotated with the same device type. - if all(val == dev_type for val in annotation_map.values()): - target = {0: target[dev_type]} - else: - raise RuntimeError("Expressions in the function are " - "annotated with various device types," - "but not device copy operators " - "found. Please check the " - "RewriteAnnotation pass.") - return func, target - - class GraphExecutor(_interpreter.Executor): """Wrapper around Executor interface. @@ -428,16 +219,19 @@ def __init__(self, mod, ctx, target): self.ctx = ctx self.target = target - def _make_executor(self, func): - ret_type = ir_pass.infer_type(func).ret_type + def _make_executor(self, expr=None): + if not expr: + assert self.mod, "either expr or self.mod should be not null." + expr = self.mod[self.mod.entry_func] + ret_type = ir_pass.infer_type(expr).ret_type num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 - graph_json, mod, params = build(func, target=self.target) + graph_json, mod, params = build(expr, target=self.target) gmodule = _graph_rt.create(graph_json, mod, self.ctx) if params: gmodule.set_input(**params) def _graph_wrapper(*args, **kwargs): - args = self._convert_args(func, args, kwargs) + args = self._convert_args(expr, args, kwargs) # Create map of inputs. 
for i, arg in enumerate(args): gmodule.set_input(i, arg) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 98b4a83e09de..8e7f95c4dc26 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -70,6 +70,38 @@ def astype(self, dtype): def __neg__(self): return _op_make.negative(self) + def __lt__(self, other): + if isinstance(other, Expr): + return _op_make.less(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __gt__(self, other): + if isinstance(other, Expr): + return _op_make.greater(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __ge__(self, other): + if isinstance(other, Expr): + return _op_make.greater_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __le__(self, other): + if isinstance(other, Expr): + return _op_make.less_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + def __add__(self, other): if isinstance(other, Expr): return _op_make.add(self, other) diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index 8d308c7e8833..76761fd78325 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -30,3 +30,4 @@ from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow +from .darknet import from_darknet diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index e92a6226072f..18489b380ee7 100644 --- 
a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -20,6 +20,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... import nd as _nd from .common import AttrCvt, Renamer @@ -382,6 +383,7 @@ def __init__(self, shape, dtype): self._ops = {} self._shape = shape self._dtype = dtype + self._mod = _module.Module({}) def from_caffe2(self, init_net, predict_net): """Construct Relay expression from caffe2 graph. @@ -393,8 +395,9 @@ def from_caffe2(self, init_net, predict_net): Returns ------- - func : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The module that optimizations will be performed on. + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -448,8 +451,9 @@ def from_caffe2(self, init_net, predict_net): outputs = out[0] func = _expr.Function(ir_pass.free_vars(outputs), outputs) + self._mod[self._mod.entry_func] = func - return func, self._params + return self._mod, self._params def _get_node(self, blob): """Get the Symbol of blob and detect cyclic dependency in the graph.""" @@ -505,7 +509,7 @@ def _convert_operator(self, identity_list=None, convert_map=None): """Convert from Caffe2 operator to Relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters @@ -560,8 +564,8 @@ def from_caffe2(init_net, predict_net, shape=None, dtype="float32"): Returns ------- - sym : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The module that optimizations will be performed on. 
params : dict of str to tvm.ndarray Dict of converted parameters stored in tvm.ndarray format diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 9b89936de015..efd198803c2b 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -241,7 +241,7 @@ def get_relay_op(op_name): op = None else: # try search op in various modules - for candidate in (_op, _op.nn, _op.image): + for candidate in (_op, _op.nn, _op.image, _op.vision): op = getattr(candidate, op_name, None) if op is not None: break @@ -286,7 +286,7 @@ def clear_padding(self): class AttrCvt(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -300,12 +300,12 @@ class AttrCvt(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in relay. Log warnings. ignores : list diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 653df92b71fc..1cac547d07c9 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -21,6 +21,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... 
import nd as _nd from ..._ffi import base as _base @@ -416,8 +417,8 @@ def from_coreml(model, shape=None): Returns ------- - func : tvm.relay.Function - Compatible relay Function. + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by Relay. @@ -463,4 +464,4 @@ def from_coreml(model, shape=None): outexpr = outexpr[0] func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return func, params + return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py new file mode 100644 index 000000000000..7b26ed5692df --- /dev/null +++ b/python/tvm/relay/frontend/darknet.py @@ -0,0 +1,849 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +""" +DarkNet symbol frontend for Relay. +""" + +from __future__ import absolute_import as _abs +from enum import Enum +import numpy as np +import tvm +from .. import ir_pass +from .. import expr as _expr +from .. 
import module as _module +from .common import get_relay_op, new_var + +__all__ = ['from_darknet'] + +def _darknet_not_support(attr, op='relay'): + """Raise error if any operation is not supported.""" + err = "{} is not supported in {}.".format(attr, op) + raise NotImplementedError(err) + +def _get_params_prefix(opname, layer_num): + """Makes the params prefix name from opname and layer number.""" + return str(opname) + str(layer_num) + +def _get_params_name(prefix, item): + """Makes the params name for the k,v pair.""" + return prefix + '_'+ item + +def _get_param_var(params, prefix, item): + name = _get_params_name(prefix, item) + if name not in params: + raise AttributeError("{} not found in params dict.".format(name)) + return new_var(name, shape=params[name].shape, dtype=params[name].dtype) + +def _darknet_maxpooling(inputs, params, attrs, prefix): + """Process the max pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 1) + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + extra_pad_size = attrs.get('extra_pad_size', 0) + if extra_pad_size: + pad_width = ((0, 0), (0, 0), (0, extra_pad_size), (0, extra_pad_size)) + inputs = [get_relay_op('pad')(*inputs, + pad_width=pad_width, + pad_value=np.finfo(np.float32).min)] + return get_relay_op('max_pool2d')(*inputs, **new_attrs) + +def _darknet_avgpooling(inputs, params, attrs, prefix): + """Process the average pool 2d operation.""" + new_attrs = {} + kernel = attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['pool_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + return get_relay_op('avg_pool2d')(*inputs, **new_attrs) + +def _darknet_conv2d(inputs, params, attrs, prefix): + """Process the convolution 2d operation.""" + new_attrs = {} + kernel = 
attrs.get('kernel') + strides = attrs.get('stride', 1) + pads = attrs.get('pad', 0) + + new_attrs['channels'] = attrs.get('num_filter') + new_attrs['kernel_size'] = (kernel, kernel) + new_attrs['strides'] = (strides, strides) + new_attrs['padding'] = (pads, pads) + new_attrs['dilation'] = attrs.get('dilate', (1, 1)) + new_attrs['groups'] = attrs.get('num_group', 1) + + weight = _get_param_var(params, prefix, 'weight') + out = get_relay_op('conv2d')(*inputs, weight=weight, **new_attrs) + + use_bias = not attrs.get('use_batchNorm', False) + if use_bias: + new_attrs = {} + new_attrs['axis'] = 1 + bias = _get_param_var(params, prefix, 'bias') + out = get_relay_op('bias_add')(out, bias=bias, **new_attrs) + else: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + out = get_relay_op('batch_norm')(out, gamma, beta, moving_mean, moving_var, **new_attrs) + + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + new_attrs['slope'] = 0.1 + out = _darknet_activations(out, None, new_attrs) + return out + +def _darknet_shortcut(inputs, params, attrs, prefix): + """Process the shortcut operation.""" + input_0 = inputs[0] + input_1 = inputs[1] + + input_0_channel = int(attrs['out_channel']) + input_1_channel = int(attrs['add_out_channel']) + input_0_size = int(attrs['out_size']) + input_1_size = int(attrs['add_out_size']) + + if input_0_size > input_1_size: + scale = int(input_0_size/input_1_size) + input_1 = get_relay_op('upsampling')(input_1, scale=scale) + + elif input_0_size < input_1_size: + stride = int(input_1_size/input_0_size) + input_1 = get_relay_op('avg_pool2d')(input_1, + pool_size=(1, 1), + strides=(stride, stride), + padding=(0, 0)) + + if input_0_channel != input_1_channel: + pad_channel = input_0_channel - 
input_1_channel + input_1 = get_relay_op('pad')(input_1, + pad_width=((0, 0), (0, pad_channel), (0, 0), (0, 0)), + pad_value=0.) + sym = input_0 + input_1 + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + sym = _darknet_activations(sym, None, new_attrs) + return sym + +def _darknet_dense(inputs, params, attrs, prefix): + """Process the dense operation.""" + new_attrs = {} + new_attrs['units'] = attrs.get('num_hidden') + data = inputs[0] + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + weight = _get_param_var(params, prefix, 'weight') + data = get_relay_op('dense')(data, weight, **new_attrs) + + use_bias = attrs.get('use_bias', False) + if use_bias: + bias = _get_param_var(params, prefix, 'bias') + data = get_relay_op('bias_add')(data, bias, axis=1) + + if 'use_batchNorm' in attrs: + new_attrs = {} + new_attrs['epsilon'] = 0.000001 + gamma = _get_param_var(params, prefix, 'gamma') + beta = _get_param_var(params, prefix, 'beta') + moving_mean = _get_param_var(params, prefix, 'moving_mean') + moving_var = _get_param_var(params, prefix, 'moving_var') + data = get_relay_op('batch_norm')(data, gamma, beta, moving_mean, moving_var, **new_attrs) + if 'activation' in attrs: + new_attrs = {} + new_attrs['activation'] = attrs['activation'] + data = _darknet_activations(data, None, new_attrs) + return data + +def _darknet_dropout(inputs, params, attrs, prefix): + """Process the dropout operation, its a blank operation.""" + new_attrs = {} + new_attrs['rate'] = attrs.get('p', 0.5) + return get_relay_op('dropout')(*inputs, **new_attrs) + +def _darknet_reshape(inputs, params, attrs, prefix): + """Process the reshape operation.""" + new_attrs = {} + new_attrs['shape'] = attrs.get('shape') + return get_relay_op('reshape')(*inputs, **new_attrs) + +def _darknet_upsampling(inputs, params, attrs, prefix): + """Process the upsampling operation.""" + new_attrs = {} + new_attrs['scale'] = 
attrs.get('scale', 1) + return get_relay_op('upsampling')(*inputs, **new_attrs) + +def _darknet_l2normalize(inputs, params, attrs, prefix): + """Process the l2 normalization operation.""" + new_attrs = {} + new_attrs['eps'] = attrs.get('eps', 0.0) + new_attrs['axis'] = [attrs.get('axis', 1)] + return get_relay_op('l2_normalize')(*inputs, **new_attrs) + +def _darknet_softmax_output(inputs, params, attrs, prefix): + """Process the softmax operation.""" + temperature = attrs.get('temperature', 1) + data = inputs[0] + if temperature != 1: + data = data / _expr.const(float(temperature)) + + if attrs.get('use_flatten', False) is True: + data = get_relay_op('batch_flatten')(data) + + new_attrs = {} + if attrs.get('multi_output', False): + new_attrs['axis'] = 1 + return get_relay_op('softmax')(data, **new_attrs) + +def _darknet_route(inputs, params, attrs, prefix): + """Process the route operation, which is equivalent to concat.""" + new_attrs = {'axis': attrs.get('dim', 1)} + return get_relay_op('concatenate')((inputs[0], inputs[1]), **new_attrs) + +def _darknet_reorg(inputs, params, attrs, prefix): + """Process the reorg operation.""" + new_attrs = {} + if 'stride' in attrs: + new_attrs = {'stride': attrs.get('stride', 1)} + return get_relay_op('yolo_reorg')(*inputs, **new_attrs) + +def _darknet_region(inputs, params, attrs, prefix): + """Process the region operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + coords = attrs.get('coords', 0) + background = attrs.get('background', 0) + softmax = attrs.get('softmax', True) + input_shape = attrs.get('shape') + + split_size = classes + coords + 1 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4, 5) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = 
split_res[2] if background else get_relay_op('sigmoid')(split_res[2]) + split_res3 = get_relay_op('softmax')(split_res[3], axis=2) if softmax else split_res[3] + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2, split_res3), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +def _darknet_yolo(inputs, params, attrs, prefix): + """Process the yolo operation.""" + num = attrs.get('n', 1) + classes = attrs.get('classes', 1) + input_shape = attrs.get('shape') + split_size = classes + 5 + intermediate_shape = (input_shape[0], num, split_size, input_shape[2], input_shape[3]) + data_block = get_relay_op('reshape')(inputs[0], newshape=intermediate_shape) + split_indices = (2, 4) + split_res = get_relay_op('split')(data_block, indices_or_sections=split_indices, axis=2) + split_res0 = get_relay_op('sigmoid')(split_res[0]) + split_res2 = get_relay_op('sigmoid')(split_res[2]) + out = get_relay_op('concatenate')((split_res0, split_res[1], split_res2), axis=2) + return get_relay_op('reshape')(out, newshape=input_shape) + +class ACTIVATION(object): + """Darknet ACTIVATION Class constant.""" + LOGISTIC = 0 + RELU = 1 + RELIE = 2 + LINEAR = 3 + RAMP = 4 + TANH = 5 + PLSE = 6 + LEAKY = 7 + ELU = 8 + LOGGY = 9 + STAIR = 10 + HARDTAN = 11 + LHTAN = 12 + +def _darknet_activations(inputs, params, attrs): + """Process the activation function.""" + act = attrs.get('activation') + data = inputs[0] if isinstance(inputs, _expr.TupleWrapper) else inputs + + def _const(val): + return _expr.const(val) + + def _relu(data): + return get_relay_op('relu')(data) + + def _exp(data): + return get_relay_op('exp')(data) + + def _tanh(data): + return get_relay_op('tanh')(data) + + def _sigmoid(data): + return get_relay_op('sigmoid')(data) + + def _elu(data): + alpha = _const(-1.0) + return alpha * _relu(_const(1.0) - _exp(data)) + _relu(data) + + def _leaky_relu(data, slope): + new_attrs = {} + new_attrs['alpha'] = slope + return get_relay_op('leaky_relu')(data, 
**new_attrs) + + if ACTIVATION.LOGISTIC == act: + data = _sigmoid(data) + elif ACTIVATION.RELU == act: + data = _relu(data) + elif ACTIVATION.TANH == act: + data = _tanh(data) + elif ACTIVATION.LINEAR == act: + return data + elif ACTIVATION.LEAKY == act: + data = _leaky_relu(data, attrs.get('slope', 0.1)) + elif ACTIVATION.ELU == act: + data = _elu(data) + else: + _darknet_not_support('act: ' + attrs) + return data + +class LAYERTYPE(Enum): + """Darknet LAYERTYPE Class constant.""" + CONVOLUTIONAL = 0 + DECONVOLUTIONAL = 1 + CONNECTED = 2 + MAXPOOL = 3 + SOFTMAX = 4 + DETECTION = 5 + DROPOUT = 6 + CROP = 7 + ROUTE = 8 + COST = 9 + NORMALIZATION = 10 + AVGPOOL = 11 + LOCAL = 12 + SHORTCUT = 13 + ACTIVE = 14 + RNN = 15 + GRU = 16 + LSTM = 17 + CRNN = 18 + BATCHNORM = 19 + NETWORK = 20 + XNOR = 21 + REGION = 22 + YOLO = 23 + REORG = 24 + UPSAMPLE = 25 + LOGXENT = 26 + L2NORM = 27 + BLANK = 28 + +_DARKNET_CONVERT_MAP = { + LAYERTYPE.CONVOLUTIONAL : _darknet_conv2d, + LAYERTYPE.CONNECTED : _darknet_dense, + LAYERTYPE.MAXPOOL : _darknet_maxpooling, + LAYERTYPE.SOFTMAX : _darknet_softmax_output, + LAYERTYPE.DROPOUT : _darknet_dropout, + LAYERTYPE.AVGPOOL : _darknet_avgpooling, + LAYERTYPE.ROUTE : _darknet_route, + LAYERTYPE.REORG : _darknet_reorg, + LAYERTYPE.REGION : _darknet_region, + LAYERTYPE.SHORTCUT : _darknet_shortcut, + LAYERTYPE.UPSAMPLE : _darknet_upsampling, + LAYERTYPE.L2NORM : _darknet_l2normalize, + LAYERTYPE.YOLO : _darknet_yolo, + LAYERTYPE.DECONVOLUTIONAL : _darknet_not_support, + LAYERTYPE.BATCHNORM : _darknet_not_support, + LAYERTYPE.DETECTION : _darknet_not_support, + LAYERTYPE.CROP : _darknet_not_support, + LAYERTYPE.COST : _darknet_not_support, + LAYERTYPE.NORMALIZATION : _darknet_not_support, + LAYERTYPE.LOCAL : _darknet_not_support, + LAYERTYPE.ACTIVE : _darknet_not_support, + LAYERTYPE.RNN : _darknet_not_support, + LAYERTYPE.GRU : _darknet_not_support, + LAYERTYPE.LSTM : _darknet_not_support, + LAYERTYPE.CRNN : _darknet_not_support, + 
LAYERTYPE.NETWORK : _darknet_not_support, + LAYERTYPE.XNOR : _darknet_not_support, + LAYERTYPE.BLANK : _darknet_not_support, +} + +def _darknet_convert_symbol(op_name, inputs, params, attrs, params_prefix): + """Convert from darknet op to relay op. + Parameters + ---------- + op_name : str + Operator name, such as Convolution, Connected, etc + inputs : list of relay.Function + List of input symbols. + attrs : dict + Dict of operator attributes + params_prefix: str + Params name for this operation + + Returns + ------- + out_name : converted out name of operation + sym : tvm.relay.Function + Converted relay function + """ + + if op_name in _DARKNET_CONVERT_MAP: + sym = _DARKNET_CONVERT_MAP[op_name](inputs, params, attrs, params_prefix) + else: + _darknet_not_support('Operator type ' + str(op_name)) + return sym + +def _as_list(arr): + """Force being a list, ignore if already is.""" + if isinstance(arr, list): + return arr + return [arr] + +class GraphProto(object): + """A helper class for handling relay functions from darknet model. 
+ """ + + def __init__(self, net, shape, dtype='float32'): + self._net = net + self._shape = shape + self._dtype = dtype + self._sym_array = {} + self._tvmparams = {} + self._outs = [] + self._state_ctr = {} + self._state_ctr['rnn'] = 0 + self._state_ctr['crnn'] = 0 + self._state_ctr['lstm'] = 0 + self._state_ctr['cell_state'] = 0 + self._state_ctr['gru'] = 0 + + def _read_memory_buffer(self, shape, data, dtype=None): + if dtype is None: + dtype = self._dtype + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) + + def _get_convolution_weights(self, layer, opname): + """Get the convolution layer weights and biases.""" + if layer.nweights == 0: + return None + + if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: + raise RuntimeError("layer weights size not matching with n c h w") + + params = {} + shape = (layer.n, layer.c, layer.size, layer.size) + weights = self._read_memory_buffer(shape, layer.weights) + + biases = self._read_memory_buffer((layer.n, ), layer.biases) + + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.n)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_connected_weights(self, layer, opname): + """Parse the weights and biases for fully connected or dense layer.""" + size = layer.outputs * layer.inputs + if size == 0: + return None + + weights = self._read_memory_buffer((layer.outputs, layer.inputs), layer.weights) + biases = self._read_memory_buffer((layer.outputs, ), layer.biases) + + params = {} + k = _get_params_name(opname, 'weight') + params[k] = tvm.nd.array(weights) + + if layer.batch_normalize == 1 and layer.dontloadscales 
!= 1: + params.update(self._get_batchnorm_weights(layer, opname, layer.outputs)) + k = _get_params_name(opname, 'beta') + params[k] = tvm.nd.array(biases) + else: + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + return params + + def _get_region_weights(self, layer, opname): + """Parse the biases for region layer.""" + biases = self._read_memory_buffer((layer.n*2, ), layer.biases) + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.coords, layer.background], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_yolo_weights(self, layer, opname): + """Parse the biases and mask for yolo layer.""" + biases = self._read_memory_buffer((layer.total*2, ), layer.biases) + mask = self._read_memory_buffer((layer.n, ), layer.mask, dtype='int32') + attributes = np.array([layer.n, layer.out_c, layer.out_h, layer.out_w, + layer.classes, layer.total], + dtype=np.int32) + params = {} + k = _get_params_name(opname, 'bias') + params[k] = tvm.nd.array(biases) + k = _get_params_name(opname, 'mask') + params[k] = tvm.nd.array(mask) + k = _get_params_name(opname, 'attr') + params[k] = tvm.nd.array(attributes) + return params + + def _get_batchnorm_weights(self, layer, opname, size): + """Parse the weights for batchnorm, which includes, scales, moving mean + and moving variances.""" + scales = self._read_memory_buffer((size, ), layer.scales) + rolling_mean = self._read_memory_buffer((size, ), layer.rolling_mean) + rolling_variance = self._read_memory_buffer((size, ), layer.rolling_variance) + + params = {} + k = _get_params_name(opname, 'moving_mean') + params[k] = tvm.nd.array(rolling_mean) + k = _get_params_name(opname, 'moving_var') + params[k] = tvm.nd.array(rolling_variance) + k = _get_params_name(opname, 'gamma') + params[k] = tvm.nd.array(scales) + return 
params + + def _get_darknet_attrs(self, layer, layer_num): + """Parse attributes of each layer and return.""" + attr = {} + use_flatten = True + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.CONVOLUTIONAL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'num_group' : layer.groups}) + attr.update({'num_filter' : layer.n}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + attr.update({'activation' : (layer.activation)}) + + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + elif LAYERTYPE.CONNECTED == layer_type: + attr.update({'num_hidden' : layer.outputs}) + attr.update({'activation' : (layer.activation)}) + if layer_num != 0: + layer_prev = self._net.layers[layer_num - 1] + if (layer_prev.out_h == layer.h and + layer_prev.out_w == layer.w and + layer_prev.out_c == layer.c): + use_flatten = False + attr.update({'use_flatten' : use_flatten}) + attr.update({'use_bias' : True}) + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + attr.update({'use_bias' : False}) + + elif LAYERTYPE.MAXPOOL == layer_type: + attr.update({'pad' : layer.pad}) + attr.update({'stride' : layer.stride}) + attr.update({'kernel' : layer.size}) + max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1 + if max_output < layer.out_w: + extra_pad = (layer.out_w - max_output)*layer.stride + attr.update({'extra_pad_size' : int(extra_pad)}) + elif LAYERTYPE.AVGPOOL == layer_type: + attr.update({'pad' : layer.pad}) + if layer.stride == 0: + attr.update({'stride' : 1}) + else: + attr.update({'stride' : layer.stride}) + if layer.size == 0 and layer.h == layer.w: + attr.update({'kernel' : layer.h}) + else: + attr.update({'kernel' : layer.size}) + + elif 
LAYERTYPE.DROPOUT == layer_type: + attr.update({'p' : layer.probability}) + + elif LAYERTYPE.SOFTMAX == layer_type: + attr.update({'axis' : 1}) + attr.update({'use_flatten' : True}) + if layer.temperature: + attr.update({'temperature' : str(layer.temperature)}) + + elif LAYERTYPE.SHORTCUT == layer_type: + add_layer = self._net.layers[layer.index] + attr.update({'activation' : layer.activation}) + attr.update({'out_channel' : layer.out_c}) + attr.update({'out_size' : layer.out_h}) + attr.update({'add_out_channel' : add_layer.out_c}) + attr.update({'add_out_size' : add_layer.out_h}) + + elif LAYERTYPE.ROUTE == layer_type: + pass + + elif LAYERTYPE.COST == layer_type: + pass + + elif LAYERTYPE.REORG == layer_type: + attr.update({'stride' : layer.stride}) + + elif LAYERTYPE.REGION == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'coords' : layer.coords}) + attr.update({'background' : layer.background}) + attr.update({'softmax' : layer.softmax}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.YOLO == layer_type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + + elif LAYERTYPE.UPSAMPLE == layer_type: + attr.update({'scale' : layer.stride}) + + elif LAYERTYPE.L2NORM == layer_type: + pass + + else: + err = "Darknet layer type {} is not supported in relay.".format(layer_type) + raise NotImplementedError(err) + + return attr + + def _get_darknet_params(self, layer, opname): + """To parse and get the darknet params.""" + layer_type = LAYERTYPE(layer.type) + params = None + if LAYERTYPE.CONVOLUTIONAL == layer_type: + params = self._get_convolution_weights(layer, opname) + elif LAYERTYPE.CONNECTED == layer_type: + params = self._get_connected_weights(layer, opname) + elif LAYERTYPE.REGION == layer_type: + params = self._get_region_weights(layer, opname) + elif LAYERTYPE.YOLO == layer_type: + params = 
self._get_yolo_weights(layer, opname) + return params + + def _preproc_layer(self, layer, layer_num): + """To preprocess each darknet layer, some layer doesnt need processing.""" + if layer_num == 0: + name = 'data' + sym = new_var(name, shape=self._shape, dtype=self._dtype) + else: + sym = self._sym_array[layer_num - 1] + skip_layer = False + layer_type = LAYERTYPE(layer.type) + if LAYERTYPE.ROUTE == layer_type: + sym = [] + for j in range(layer.n): + sym.append(self._sym_array[layer.input_layers[j]]) + if layer.n == 1: + skip_layer = True + + elif LAYERTYPE.COST == layer_type: + skip_layer = True + + elif LAYERTYPE.SHORTCUT == layer_type: + sym = [sym, self._sym_array[layer.index]] + + elif LAYERTYPE.BLANK == layer_type: + skip_layer = True + + if skip_layer is True: + self._sym_array[layer_num] = sym + + return skip_layer, sym + + def _get_opname(self, layer): + """Returs the layer name.""" + return LAYERTYPE(layer.type) + + def _new_rnn_state_var(self, state=None, name='rnn'): + """Returs a symbol for state""" + sym_name = name + "%d_state" % self._state_ctr[name] + self._state_ctr[name] += 1 + return new_var(sym_name, shape=state.shape, dtype=str(state.dtype)) + + def _get_rnn_state_buffer(self, layer, name): + """Get the state buffer for rnn.""" + buffer = np.zeros((1, layer.outputs), self._dtype) + return self._new_rnn_state_var(buffer, name) + + def _get_darknet_rnn_attrs(self, layer, name, sym): + """Get the rnn converted symbol from attributes.""" + attr = self._get_darknet_attrs(layer, 0) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, name) + params = self._get_darknet_params(layer, prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + if params: + self._tvmparams.update(params) + return sym + + def _handle_darknet_rnn_layers(self, layer_num, sym): + """Parse attributes and handle the rnn layers.""" + attr = {} + layer = self._net.layers[layer_num] + processed = False + + layer_type = 
LAYERTYPE(layer.type) + if LAYERTYPE.RNN == layer_type: + attr.update({'n' : layer.n}) + attr.update({'batch' : layer.batch}) + attr.update({'num_hidden' : str(layer.outputs)}) + state = self._get_rnn_state_buffer(layer, 'rnn') + for _ in range(layer.steps): + input_layer = layer.input_layer + prefix = "_input_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(input_layer, prefix, sym) + + self_layer = layer.self_layer + prefix = "_self_" + str(layer_num) + state = self._get_darknet_rnn_attrs(self_layer, prefix, state) + + state = sym + state + self._outs.append(state) + + output_layer = layer.output_layer + prefix = "_output_" + str(layer_num) + sym = self._get_darknet_rnn_attrs(output_layer, prefix, state) + + self._sym_array[layer_num] = sym + processed = True + return processed, sym + + def _make_outlist(self, sym, op_name, layer, layer_num): + layer_type = LAYERTYPE(layer.type) + if layer_type == LAYERTYPE.REGION: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + if layer_num != self._net.n-1: + self._outs.insert(0, sym) + + elif layer_type == LAYERTYPE.YOLO: + #Add attributes + k = _get_params_name(op_name, 'attr') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add bias + k = _get_params_name(op_name, 'bias') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + #Add mask + k = _get_params_name(op_name, 'mask') + dshape = self._tvmparams[k].shape + dtype = self._tvmparams[k].dtype + self._outs.insert(0, new_var(k, shape=dshape, dtype=dtype)) + + if layer_num != 
self._net.n-1: + self._outs.insert(0, sym) + + def from_darknet(self): + """To convert the darknet symbol to relay functions.""" + for i in range(self._net.n): + layer = self._net.layers[i] + need_skip, sym = self._preproc_layer(layer, i) + if need_skip: + continue + + processed, sym = self._handle_darknet_rnn_layers(i, sym) + if processed: + continue + + attr = self._get_darknet_attrs(layer, i) + op_name = self._get_opname(layer) + prefix = _get_params_prefix(op_name, i) + params = self._get_darknet_params(self._net.layers[i], prefix) + sym = _darknet_convert_symbol(op_name, _as_list(sym), params, attr, prefix) + + if params: + self._tvmparams.update(params) + self._sym_array[i] = sym + self._make_outlist(sym, prefix, layer, i) + + outputs = _as_list(sym) + self._outs + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + sym = _expr.Function(ir_pass.free_vars(outputs), outputs) + return _module.Module.from_expr(sym), self._tvmparams + +def from_darknet(net, + shape=None, + dtype="float32"): + """Convert from Darknet's model into compatible relay Function. + + Parameters + ---------- + net : Darknet net parameter + Darknet net structure. + shape : dict of str to tuple, optional + The input shape to the graph + dtype : str or dict of str to str + The input types to the graph + + Returns + ------- + mod : tvm.relay.Module + The relay module for compilation. + + params : dict of str to tvm.NDArray + The parameter dict to be used by relay + """ + + return GraphProto(net, shape, dtype).from_darknet() diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 2648a5a6637b..ad033f9bf326 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -22,6 +22,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... 
import nd as _nd from .common import ExprTable, new_var @@ -203,7 +204,6 @@ def _convert_convolution(inexpr, keras_layer, etab): else: kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape weight = weightList[0].transpose([3, 2, 0, 1]) - dilation = [1, 1] if isinstance(keras_layer.dilation_rate, (list, tuple)): dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]] else: @@ -680,8 +680,8 @@ def from_keras(model, shape=None): Returns ------- - func : tvm.relay.Function - Compatible relay Function. + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by Relay. @@ -745,4 +745,4 @@ def _convert_input_layer(keras_layer): outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} - return func, params + return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1a4d52f5b679..2f36355abf23 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -23,6 +23,7 @@ from .. import ir_pass from .. import expr as _expr from .. import op as _op +from .. import module as _module from ... 
import nd as _nd from .common import StrAttrsDict @@ -93,6 +94,15 @@ def impl(inputs, attrs): return impl +def _mx_zeros(inputs, attrs): + assert len(inputs) == 0 + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + if 0 in shape: + return None + return _op.zeros(shape=shape, dtype=dtype) + + def _mx_conv2d(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") if len(kernel_size) != 2: @@ -149,7 +159,7 @@ def _mx_conv2d_transpose(inputs, attrs): new_attrs["groups"] = attrs.get_int("num_group", 1) new_attrs["data_layout"] = data_layout new_attrs["kernel_layout"] = kernel_layout - use_bias = not attrs.get_bool("no_bias", False) + use_bias = not attrs.get_bool("no_bias", True) res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs) if use_bias: @@ -277,6 +287,28 @@ def _mx_slice_axis(inputs, attrs): return _op.strided_slice(inputs[0], begin, end) +def _mx_crop_like(inputs, attrs): + if len(inputs) < 2: + raise tvm.error.OpAttributeUnimplemented( + "Only support crop_like pattern for operator Crop.") + if attrs.get_bool("center_crop", False): + raise tvm.error.OpAttributeUnimplemented( + "Center crop is not supported in operator Crop.") + if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0): + raise tvm.error.OpAttributeUnimplemented( + "Doesn't support h_w in operator Crop.") + offset = attrs.get_int_tuple("offset", (0, 0)) + new_attrs = {} + if offset == (0, 0): + new_attrs["axes"] = (2, 3) + return _op.slice_like(*inputs, **new_attrs) + like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape + new_attrs['begin'] = [0, 0, offset[0], offset[1]] + new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], + offset[1]+like_shape[3]] + return _op.strided_slice(inputs[0], **new_attrs) + + def _mx_split(inputs, attrs): axis = attrs.get_int("axis", 1) new_attrs = {} @@ -300,6 +332,10 @@ def _mx_softmax_output(inputs, attrs): return _op.nn.softmax(inputs[0]) +def _mx_linear_regression_output(inputs, _): + return 
inputs[0] + + def _mx_concat(inputs, attrs): axis = attrs.get_int("dim", 1) return _op.concatenate(tuple(inputs), axis=axis) @@ -543,7 +579,8 @@ def _mx_box_nms(inputs, attrs): raise tvm.error.OpAttributeInvalid( 'Value of attribute "out_format" must equal "corner" for operator box_nms.') - ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) + ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh, + id_index=id_index, score_index=score_index) nms_out = _op.vision.non_max_suppression(ret[1], ret[0], iou_threshold=iou_thresh, @@ -657,6 +694,21 @@ def _mx_argsort(inputs, attrs): return _op.argsort(inputs[0], **new_attrs) +def _mx_topk(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["k"] = attrs.get_int("k", 1) + new_attrs["axis"] = attrs.get_int("axis", -1) + new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True) + ret_type = attrs.get_str("ret_typ", "indices") + if ret_type == "mask": + raise tvm.error.OpAttributeUnimplemented( + "Attribute ret_type=mask is not supported in topk operator") + new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type + new_attrs["dtype"] = attrs.get_str("dtype", "float32") + return _op.topk(inputs[0], **new_attrs) + + def _mx_rnn_param_concat(inputs, _): # We don't need to concatenate RNN params because we will unravel the RNN op return [inputs] @@ -696,13 +748,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): num_layers = attrs.get_int("num_layers", 1) mode = attrs.get_str("mode") + output_states = attrs.get_bool("state_outputs", False) if mode.startswith("rnn"): mode, activation = mode.split('_') assert mode in ["rnn", "gru", "lstm"] bidirectional = attrs.get_bool("bidirectional", False) - if bidirectional: - raise tvm.error.OpAttributeUnimplemented( - "Bidirectional RNN op is not supported yet") + direct = 2 if bidirectional else 1 layout = attrs.get_str("layout", "TNC") if layout != "TNC": raise 
tvm.error.OpAttributeUnimplemented( @@ -712,44 +763,98 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_data = inputs[0] concat_weight = inputs[1] - concat_states = inputs[2:] - seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0]) - assert len(concat_weight) == num_layers * 4 + init_states = inputs[2:] + data_shape = ir_pass.infer_type(seq_data).checked_type.shape + seq_len = int(data_shape[0]) + assert len(concat_weight) == num_layers * 4 * direct + + for idx, state in enumerate(init_states[:]): + if isinstance(state, dict): + node = state + attrs = StrAttrsDict(node.get("attrs", {})) + op_name = node["op"] + # by default, RNN layer uses zeros to initialize states + assert op_name == "_zeros" + shape = attrs.get_int_tuple("shape") + dtype = attrs.get_str("dtype", "float32") + init_layout = attrs.get_str("__layout__") + new_shape = list(shape) + for i, dim in enumerate(shape): + if dim == 0: + axis = layout.find(init_layout[i]) + assert axis >= 0 + new_shape[i] = int(data_shape[axis]) + init_states[idx] = _op.zeros(new_shape, dtype) weights = [] bias = [] states = [] + back_weights = [] + back_bias = [] + back_states = [] for i in range(num_layers): - w = [] - b = [] + weights.append([concat_weight[i*2*direct].args[0], + concat_weight[i*2*direct + 1].args[0]]) + bias.append([concat_weight[(num_layers+i)*2*direct].args[0], + concat_weight[(num_layers+i)*2*direct + 1].args[0]]) s = [] - for j in range(2): - w.append(concat_weight[i*2 + j].args[0]) - b.append(concat_weight[num_layers*2 + i*2 + j].args[0]) - for state in concat_states: - s.append(_op.take(state, _expr.const(i, "int32"), axis=0)) - weights.append(w) - bias.append(b) + for state in init_states: + s.append(_op.take(state, _expr.const(i*direct, "int32"), axis=0)) states.append(s) - - seq_output = [] - for t in range(seq_len): - data = _op.take(seq_data, _expr.const(t, "int32"), axis=0) - for l in range(num_layers): + if bidirectional: + 
back_weights.append([concat_weight[i*2*direct + 2].args[0], + concat_weight[i*2*direct + 3].args[0]]) + back_bias.append([concat_weight[(num_layers+i)*2*direct + 2].args[0], + concat_weight[(num_layers+i)*2*direct + 3].args[0]]) + s = [] + for state in init_states: + s.append(_op.take(state, _expr.const(i*direct+1, "int32"), axis=0)) + back_states.append(s) + + xs = [_op.take(seq_data, _expr.const(t, "int32"), axis=0) for t in range(seq_len)] + for l in range(num_layers): + outputs = [] + back_outputs = [] + for x in xs: if mode == "rnn": - out, new_states = _rnn_cell(data, states[l], *weights[l], *bias[l], activation) + out, new_states = _rnn_cell(x, states[l], *weights[l], *bias[l], activation) elif mode == "gru": - out, new_states = _gru_cell(data, states[l], *weights[l], *bias[l]) + out, new_states = _gru_cell(x, states[l], *weights[l], *bias[l]) else: # mode == "lstm" - out, new_states = _lstm_cell(data, states[l], *weights[l], *bias[l]) + out, new_states = _lstm_cell(x, states[l], *weights[l], *bias[l]) states[l] = new_states - data = out - seq_output.append(out) - - outputs = [_op.stack(seq_output, axis=0)] - for i in range(num_states): - outputs.append(_op.stack([s[i] for s in states], axis=0)) - return outputs + outputs.append(out) + if bidirectional: + for x in reversed(xs): + if mode == "rnn": + out, new_states = _rnn_cell( + x, back_states[l], *back_weights[l], *back_bias[l], activation) + elif mode == "gru": + out, new_states = _gru_cell( + x, back_states[l], *back_weights[l], *back_bias[l]) + else: # mode == "lstm" + out, new_states = _lstm_cell( + x, back_states[l], *back_weights[l], *back_bias[l]) + back_states[l] = new_states + back_outputs.append(out) + back_outputs.reverse() + concat_outputs = [] + for t, out in enumerate(outputs): + new_out = _op.concatenate([out, back_outputs[t]], axis=-1) + concat_outputs.append(new_out) + outputs = concat_outputs + xs = outputs + + ret = [_op.stack(outputs, axis=0)] + if output_states: + for i in 
range(num_states): + inputs = [] + for l, s in enumerate(states): + inputs.append(s[i]) + if bidirectional: + inputs.append(back_states[l][i]) + ret.append(_op.stack(inputs, axis=0)) + return ret # Note: due to attribute conversion constraint @@ -839,7 +944,6 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "argmin" : _arg_reduce(_op.argmin), # init ops "_ones" : _init_op(_op.ones), - "_zeros" : _init_op(_op.zeros), # softmax "softmax" : _softmax_op(_op.nn.softmax), "log_softmax" : _softmax_op(_op.nn.log_softmax), @@ -853,6 +957,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations + "_zeros" : _mx_zeros, "FullyConnected": _mx_fully_connected, "Activation" : _mx_activations, "Convolution" : _mx_conv2d, @@ -888,8 +993,10 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, "argsort" : _mx_argsort, + "topk" : _mx_topk, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, + "LinearRegressionOutput" : _mx_linear_regression_output, "smooth_l1" : _mx_smooth_l1, # vision "_contrib_BilinearResize2D" : _mx_resize, @@ -905,18 +1012,20 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): # NLP "RNN" : _mx_rnn_layer, "_rnn_param_concat" : _mx_rnn_param_concat, + # Deprecated: + "Crop" : _mx_crop_like, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # # "broadcast_to", - # "Crop" : _crop_like, } # set identity list _convert_map.update({k : _rename(k) for k in _identity_list}) -def _from_mxnet_impl(symbol, shape_dict, dtype_info): +def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): + #pylint: disable=unused-argument """Convert mxnet symbol to compatible relay Function. Reconstruct a relay Function by traversing the mxnet symbol. 
@@ -933,6 +1042,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): dtype_info : dict or str. Known parameter dtypes + mod : tvm.relay.Module + The module that contains global information. It will be used for + converting ops that need global information, e.g. control-flow ops. + Returns: ------- func : tvm.relay.Function @@ -957,7 +1070,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info): node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)] elif op_name in _convert_map: res = _convert_map[op_name](children, attrs) - if isinstance(res, (_expr.TupleWrapper, tuple, list)): + if res is None: + # defer conversion, used in RNN state initialization + res = [node] + elif isinstance(res, (_expr.TupleWrapper, tuple, list)): pass elif isinstance(res, _expr.Expr): res = [res] @@ -1018,8 +1134,8 @@ def from_mxnet(symbol, Returns ------- - sym : tvm.relay.Function - Compatible relay Function + mod : tvm.relay.Module + The relay module for compilation params : dict of str to tvm.NDArray The parameter dict to be used by nnvm @@ -1029,6 +1145,7 @@ def from_mxnet(symbol, except ImportError as e: raise ImportError("{}. 
MXNet is required to parse symbols.".format(e)) + mod = _module.Module() if isinstance(symbol, mx.sym.Symbol): params = {} arg_params = arg_params if arg_params else {} @@ -1038,7 +1155,7 @@ def from_mxnet(symbol, for k, v in aux_params.items(): params[k] = _nd.array(v.asnumpy()) shape, dtype = _update_shape_dtype(shape, dtype, params) - sym = _from_mxnet_impl(symbol, shape, dtype) + func = _from_mxnet_impl(symbol, shape, dtype, mod) elif isinstance(symbol, mx.gluon.HybridBlock): if arg_params is not None or aux_params is not None: raise ValueError("arg_params and aux_params ae not used when importing HybridBlock") @@ -1050,10 +1167,11 @@ def from_mxnet(symbol, if isinstance(sym, (list, tuple)): sym = mx.sym.Group(sym) shape, dtype = _update_shape_dtype(shape, dtype, params) - sym = _from_mxnet_impl(sym, shape, dtype) + func = _from_mxnet_impl(sym, shape, dtype, mod) elif isinstance(symbol, mx.gluon.Block): raise NotImplementedError("Only Hybrid Blocks are supported now.") else: msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol)) raise ValueError(msg) - return sym, params + mod[mod.entry_func] = func + return mod, params diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index eba02e70c865..bb968ec0bea8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -23,7 +23,9 @@ import tvm from ... import nd as _nd from .. import ir_pass +from .. import transform as _transform from .. import expr as _expr +from .. import module as _module from .. 
import op as _op from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels, get_name @@ -52,6 +54,15 @@ def revert_caffe2_pad(pads): 'Number of pads must be either 2 or 4.') return pads + +def onnx_storage_order2layout(storage_order): + """converter of onnx storage order parameter to tvm storage order format""" + if storage_order not in (0, 1): + raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1') + + return 'NCHW' if storage_order == 0 else 'NHWC' + + def dimension_constraint(): def _dim_check(attrs): if len(attrs['kernel_shape']) == 2: @@ -60,6 +71,7 @@ def _dim_check(attrs): return _dim_check, "Only 2d kernel supported." + class OnnxOpConverter(object): """ A helper class for holding onnx op converters. """ @@ -108,6 +120,7 @@ def _impl_v1(cls, inputs, attr, params): inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2) return get_relay_op(op_name)(*inputs) + class Pool(OnnxOpConverter): """ A helper class for pool op converters. """ @@ -247,6 +260,7 @@ def _impl_v1(cls, inputs, attr, params): inputs[1], units=channels) return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) + class MatMul(OnnxOpConverter): """ Operator converter for MatMul. """ @@ -257,9 +271,40 @@ def _impl_v1(cls, inputs, attr, params): input_1_t = _op.transpose(inputs[1], axes=(1, 0)) return _op.nn.dense(inputs[0], input_1_t) + class MaxPool(Pool): + """ Operator converter for MaxPool + """ name = 'max_pool' + @classmethod + def _impl_v8(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + # TODO(higumachan): make sure ceil_mode in onnx, and layout? 
+ extras={'ceil_mode': False}, + custom_check=dimension_constraint())(inputs, attr, params) + + @classmethod + def _impl_v10(cls, inputs, attr, params): + return AttrCvt( + op_name=dimension_picker(cls.name), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', (0, 0), revert_caffe2_pad), + 'storage_order': ('layout', 'NCHW', onnx_storage_order2layout), + 'ceil_mode': 'ceil_mode' + }, + # very weird attributes here in onnx, force check + ignores=['dilations', 'auto_pad'], + custom_check=dimension_constraint())(inputs, attr, params) class Mul(Elemwise): name = 'multiply' @@ -365,21 +410,27 @@ def _impl_v1(cls, inputs, attr, params): shape = tuple(params[inputs[1].name_hint].asnumpy()) out = _op.reshape(inputs[0], shape) else: - # Try to infer shape by precompute prune if possible. - # TODO: good to check inputs to be in params. - # to be enhanced when relay support list_input_names API of NNVM - logging.warning("Infering Reshape argument by precompute") - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + data, shape = inputs + logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") + shape_params = ir_pass.free_vars(shape) + func = _expr.Function(shape_params, shape) + mod = _module.Module.from_expr(func) + seq = _transform.Sequential([_transform.InferType(), + _transform.FoldConstant(), + _transform.FuseOps(0), + _transform.InferType()]) + with tvm.relay.PassContext(opt_level=2): + mod = seq(mod) with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) - inputs.pop(1) - out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten())) + ex = tvm.relay.create_executor("debug", mod=mod) + inputs = [] + for sp in shape_params: + if not sp.name_hint in 
params: + sh = [int(i) for i in sp.type_annotation.shape] + inputs.append( + tvm.nd.array(np.random.rand(*sh).astype('float32'))) + static_shape = ex.evaluate()(*inputs, **params) + out = _op.reshape(data, newshape=tuple(static_shape.asnumpy())) return out @@ -524,6 +575,7 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + # TODO(@jroesch): use shape_of once it has been fixed) return _op.shape_of(inputs[0]) class Cast(OnnxOpConverter): @@ -622,6 +674,23 @@ def _impl_v1(cls, inputs, attr, params): extras={'axis':axis})(inputs, {}) #return _op.take(inputs[0], inputs[1], axis) + +class Greater(OnnxOpConverter): + """ Operator logical greater. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.greater(inputs[0], inputs[1]) + + +class Less(OnnxOpConverter): + """ Operator logical less than. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.less(inputs[0], inputs[1]) + + class LRN(OnnxOpConverter): """ Operator converter for Local Response Normalization. 
""" @@ -836,6 +905,8 @@ def _get_convert_map(opset): 'Selu': Selu.get_converter(opset), 'Elu': Elu.get_converter(opset), 'Exp': Renamer('exp'), + 'Greater': Greater.get_converter(opset), + 'Less': Less.get_converter(opset), 'Log': Renamer('log'), 'Tanh': Renamer('tanh'), 'Pow': Renamer('power'), @@ -915,7 +986,7 @@ def __init__(self, shape, dtype): self._renames = {} self._num_input = 0 self._num_param = 0 - self._shape = shape + self._shape = shape if shape else {} self._dtype = dtype def from_onnx(self, graph, opset): @@ -937,8 +1008,9 @@ def from_onnx(self, graph, opset): Returns ------- - sym : tvm.relay.expr.Function - The returned relay function + mod : tvm.relay.Module + The returned relay module + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -947,6 +1019,9 @@ def from_onnx(self, graph, opset): if not init_tensor.name.strip(): raise ValueError("Tensor's name is required.") self._params[init_tensor.name] = self._parse_array(init_tensor) + self._nodes[init_tensor.name] = new_var(init_tensor.name, + shape=self._params[init_tensor.name].shape, + dtype=self._params[init_tensor.name].dtype) for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' @@ -970,6 +1045,19 @@ def from_onnx(self, graph, opset): else: dtype = d_type self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype) + # get list of unsupported ops + convert_map = _get_convert_map(opset) + unsupported_ops = set() + for node in graph.node: + op_name = node.op_type + if op_name not in convert_map and \ + op_name != 'Constant' and \ + op_name not in _identity_list: + unsupported_ops.add(op_name) + if unsupported_ops: + msg = 'The following operators are not supported for frontend ONNX: ' + msg += ', '.join(unsupported_ops) + raise tvm.error.OpNotImplemented(msg) # construct nodes, nodes are stored as directed acyclic graph for node in graph.node: op_name = node.op_type @@ -978,8 +1066,15 @@ def 
from_onnx(self, graph, opset): if op_name == "Constant": t_proto = self._parse_attr(node.attribute)["value"] self._num_param += 1 - self._params[node.output[0]] = self._parse_array(t_proto) - self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims)) + # We should convert scalar integers to int32, to normalize. + array = self._parse_array(t_proto) + if len(array.shape) == 0 and array.dtype == 'int64': + array = _nd.array(array.asnumpy().astype('int32')) + self._params[node.output[0]] = array + self._nodes[node.output[0]] = new_var( + node.output[0], + shape=list(t_proto.dims), + dtype=array.dtype) else: if op_name == "ConstantFill": fill_value = attr.get('value', 0.0) @@ -1012,7 +1107,7 @@ def from_onnx(self, graph, opset): outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs) - return func, self._params + return _module.Module.from_expr(func), self._params def _parse_value_proto(self, value_proto): """Parse ValueProto or raw str.""" @@ -1076,7 +1171,7 @@ def _convert_operator(self, attrs, opset): """Convert ONNX operator into a Relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. 
Parameters @@ -1141,17 +1236,29 @@ def from_onnx(model, Returns ------- - sym : tvm.relay.expr.Function - Compatible relay function + mod : tvm.relay.Module + The relay module for compilation params : dict of str to tvm.NDArray The parameter dict to be used by relay """ + try: + import onnx + if hasattr(onnx.checker, 'check_model'): + # try use onnx's own model checker before converting any model + try: + onnx.checker.check_model(model) + except onnx.onnx_cpp2py_export.checker.ValidationError as e: + import warnings + # the checker is a bit violent about errors, so simply print warnings here + warnings.warn(str(e)) + except ImportError: + pass g = GraphProto(shape, dtype) graph = model.graph try: opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 - sym, params = g.from_onnx(graph, opset) - return sym, params + mod, params = g.from_onnx(graph, opset) + return mod, params diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 48f78837c525..1b5573121e20 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -31,9 +31,24 @@ from .. import expr as _expr from .. import op as _op from ..expr_functor import ExprMutator +from .. import module as _module __all__ = ['from_tensorflow'] +def _infer_value(input_val, params): + from tvm.contrib import graph_runtime + # Check that all free variables have associated parameters. + assert all(var.name_hint in params.keys() for var in ir_pass.free_vars( + input_val)), "All inputs to infer must be available in params." 
+ func = _expr.Function(ir_pass.free_vars(input_val), input_val) + with tvm.relay.build_config(opt_level=0): + graph, lib, params = tvm.relay.build(func, target="llvm", params=params) + ctx = tvm.context("llvm", 0) + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + m.run() + return m.get_output(0) + def _get_relay_op(op_name): try: op = getattr(_op, op_name) @@ -49,7 +64,7 @@ def _get_relay_op(op_name): return op class AttrCvt(object): - """Common attribute conveter. An AttrConverter instance is a callable: + """Common attribute converter. An AttrConverter instance is a callable: ``` attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) new_op_name, new_attr = attr_converter(attrs) @@ -63,12 +78,12 @@ class AttrCvt(object): `op_name = func(attr)` transforms : dict of `new_name, or (new_name, default_value, transform function)` If only a new_name is provided, it's like renaming the attribute name. - If default_value if provded, then the attribute is considered as optional. + If default_value if provided, then the attribute is considered as optional. If transform function is provided, the original attribute value is handled by transform function. excludes : list A list of excluded attributes that should `NOT` appear. - Raise NotImplementedError if occured. + Raise NotImplementedError if occurred. disables : list A list of attributes that is disabled in relay. Log warnings. ignores : list @@ -208,17 +223,37 @@ def _dim_check(attrs): return False return _dim_check, "Only 2d kernel supported." -def _infer_channels(inputs, params, transpose=False): - """A hack for getting 'channles' or 'units' since tensorflow don't provide +def _infer_channels(node, params, transpose=False): + """A hack for getting 'channels' or 'units' since tensorflow don't provide these attributes. We check the shape of weights provided to get the number. 
""" - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - channels = out_shapes[0][0] if not transpose else out_shapes[0][1] + out_shape = _infer_shape(node, params) + channels = out_shape[0] if not transpose else out_shape[1] return channels +def _infer_out_shapes(inputs, params): + """A method to get the output shape of intermediate nodes in the relay graph.""" + return [_infer_shape(inputs, params)] + +def _infer_shape(node, params=None): + """A method to get the output shape of an intermediate node in the relay graph.""" + out_type = ir_pass.infer_type(node) + return get_const_tuple(out_type.checked_type.shape) + +def _get_param(params, input_node): + return params.pop(input_node.name_hint).asnumpy() + +def _get_num_param(params, input_node): + return _get_param(params, input_node)[0] + +def _get_list_param(params, input_node): + return _get_param(params, input_node).tolist() + +def _get_tuple_param(params, input_node): + return tuple(_get_param(params, input_node)) + def _rsqrt(): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) return AttrCvt(op_name="power")(inputs, attr) return _impl @@ -229,16 +264,15 @@ def _impl(inputs, attr, params): try: # In Tensorflow, `axis` argument is a Tensor, not attribute. We # support the case where it inputs from a scalar constant. 
- axis_input_name = inputs[1].name_hint - axis_input_vlaue = [params[axis_input_name].asnumpy()[0]] + axis_input_value = [_get_num_param(params, inputs[1])] except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for `{}` : `axis` should be a constant".format(func_name)) - return func(inputs[0], axis=axis_input_vlaue, keepdims=False) + return func(inputs[0], axis=axis_input_value, keepdims=False) return _impl def _elemwise(name): - def _impl(inputs, attr, *args): + def _impl(inputs, attr, params): assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs)) return _get_relay_op(name)(*inputs) return _impl @@ -450,6 +484,54 @@ def _impl(inputs, attr, params): return inputs[0] return _impl +def _crop_and_resize(): + def _impl(inputs, attr, params): + # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] + # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] + try: + boxes = params.pop(inputs[1].name_hint).asnumpy().tolist() + box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist() + crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist() + except (IndexError, KeyError): + boxes = _infer_value(inputs[1], params).asnumpy().tolist() + box_ind = _infer_value(inputs[2], params).asnumpy().tolist() + crop_size = _infer_value(inputs[3], params).asnumpy().tolist() + + data_shape = attr['_input_shapes'][inputs[0]] + data_dim = len(data_shape) + method = attr['method'].decode() + + attrs = {} + attrs['size'] = crop_size + attrs['layout'] = 'NHWC' + if method.lower() == 'nearest': + raise tvm.error.OpAttributeUnimplemented( + 'Attribute method=nearest is not supported') + else: + attrs['align_corners'] = True + attrs['method'] = 'BILINEAR' + + out = None + begin = [0] * data_dim + size = data_shape[:] + for idx in box_ind: + # 1) Crop + # y is mapped to the image coordinate at y * (image_height - 1) + # x is mapped to the image coordinate at x * (image_width - 1) + begin[0] = idx + 
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1))) + begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1))) + size[0] = idx + 1 + size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1 + size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1 + res_crop = _op.strided_slice(inputs[0], begin=begin, end=size) + + # 2) Resize + res_resize = _get_relay_op('resize')(res_crop, **attrs) + out = _op.concatenate([out, res_resize], axis=0) if out else res_resize + return out + return _impl + def _cast(): def _impl(inputs, attr, params): return inputs[0].astype(attr['DstT'].name) @@ -458,14 +540,19 @@ def _impl(inputs, attr, params): def _expand_dims(): def _impl(inputs, attr, params): dim_input = inputs.pop(1) - axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0] + axis = _get_num_param(params, dim_input) return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'], extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr) return _impl def _resize_bilinear(): def _impl(inputs, attr, params): - attr['size'] = attr['_output_shapes'][0][1:3] + size = attr['_output_shapes'][0][1:3] + # Important that the size is defined. If an axis is not, we need to infer what + # the shape should be. 
+ if -1 in size: + size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + attr['size'] = size inputs.pop(1) # NHWC attr['layout'] = 'NHWC' @@ -475,6 +562,21 @@ def _impl(inputs, attr, params): extras={'method': "BILINEAR"})(inputs, attr) return _impl +def _resize_nearest_neighbor(): + def _impl(inputs, attr, params): + size = attr['_output_shapes'][0][1:3] + if -1 in size: + size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + attr['size'] = size + inputs.pop(1) + # NHWC + attr['layout'] = 'NHWC' + + return AttrCvt(op_name="resize", + ignores=['Tdim'], + extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr) + return _impl + def _check_numerics(): def _impl(inputs, attr, params): # Making a copy node assuming no need to verify @@ -508,21 +610,19 @@ def _impl(inputs, attr, params): def _concatV2(): def _impl(inputs, attr, params): pop_node = inputs.pop(len(inputs)-1) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['T', 'N', 'Tidx'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _concat(): def _impl(inputs, attr, params): pop_node = inputs.pop(0) - axis = params[pop_node.name_hint] - params.pop(pop_node.name_hint) + axis = int(_get_num_param(params, pop_node)) return AttrCvt( op_name="concatenate", ignores=['N'], - extras={'axis': int(axis.asnumpy()[0])})([inputs], attr) + extras={'axis': axis})([inputs], attr) return _impl def _pack(): @@ -546,8 +646,8 @@ def _impl(inputs, attr, params): def _slice(): def _impl(inputs, attr, params): - begin = params.pop(_get_name_hint(inputs[1])).asnumpy().tolist() - size = params.pop(_get_name_hint(inputs[2])).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + size = _get_list_param(params, inputs[2]) data_shape = attr['_input_shapes'][inputs[0]] data_dim = len(data_shape) end = size @@ -556,43 
+656,83 @@ def _impl(inputs, attr, params): end[i] = data_shape[i] - begin[i] else: end[i] += begin[i] - return _op.strided_slice(inputs[0], begin=begin, end=size) + return _op.strided_slice(inputs[0], begin=begin, end=end) return _impl def _reshape(): def _impl(inputs, attr, params): + pop_node = inputs.pop(1) try: - pop_node = inputs[1] - shape_arg = params.pop(pop_node.name_hint) - inputs.pop(1) - - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(shape_arg.asnumpy())}, - ignores=['Tshape'])(inputs, attr) + shape_arg = _get_tuple_param(params, pop_node) except AttributeError: # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) - with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) - inputs.pop(1) - return AttrCvt( - op_name="reshape", - extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())}, - ignores=['Tshape'])(inputs, attr) + params_new = _infer_value(pop_node, params) + shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + return AttrCvt( + op_name="reshape", + extras={'newshape': shape_arg}, + ignores=['Tshape'])(inputs, attr) + return _impl + + +def _depth_to_space(): + def _impl(inputs, attr, params): + # Need to handle data layouts differently. + input_shape = attr['_input_shapes'][inputs[0]] + block_size = int(attr['block_size']) + if attr['data_format'].decode("utf-8") == 'NHWC': + in_n, in_h, in_w, in_c = input_shape + new_c = int(in_c / (block_size * block_size)) + + # First expand input to larger dimension. + expanded = _op.reshape( + inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c)) + # Now reorder to expand spatial blocks. 
+ transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5)) + # Finally reshape to proper output. + new_h = in_h * block_size + new_w = in_w * block_size + newshape = (in_n, new_h, new_w, new_c) + + else: # Handle NCHW layout + in_n, in_c, in_h, in_w = input_shape + new_c = int(in_c / (block_size * block_size)) + + expanded = _op.reshape( + inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w)) + transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2)) + new_h = in_h * block_size + new_w = in_w * block_size + newshape = (in_n, new_c, new_h, new_w) + + return AttrCvt( + op_name="reshape", + extras={'newshape': newshape}, + ignores=['data_format', 'block_size'])([transposed], attr) + return _impl + def _bias_add(): def _impl(inputs, attr, params): - return _op.add(inputs[0], inputs[1]) + # Must expand for proper broadcasting in NCHW. + if attr['data_format'].decode("utf-8") == 'NCHW': + bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) + else: + bias = inputs[1] + return _op.add(inputs[0], bias) + return _impl + +def _broadcast_to(): + def _impl(inputs, attr, params): + if isinstance(inputs[1], _expr.Var): + shape = params[inputs[1].name_hint] + else: + shape = _infer_value(inputs[1], params) + shape = list(shape.asnumpy().reshape([-1])) + return _op.broadcast_to(inputs[0], shape) return _impl def _squeeze(): @@ -666,9 +806,16 @@ def _impl(inputs, attr, params): def _fill(): def _impl(inputs, attr, params): - fill_arg = params.pop(inputs.pop(1).name_hint) - return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name), - attr['_output_shapes'][0], attr['T'].name) + output_shape = attr['_output_shapes'][0] + # Output shape must be defined to avoid errors. If any axis is not, we must + # try to compute its shape. 
+ if -1 in output_shape: + output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() + + fill_arg = _get_num_param(params, inputs.pop(1)) + dtype = attr['T'].name + return _op.full(tvm.relay.const(fill_arg, dtype), + output_shape, dtype) return _impl def _lrn(): @@ -685,12 +832,21 @@ def _impl(inputs, attr, params): return _impl def _sum(): + def _impl(inputs, attr, params): + axis = _get_tuple_param(params, inputs[1]) + return AttrCvt( + op_name='sum', + extras={'axis': axis}, + transforms={'keep_dims':'keepdims'}, + ignores=['name', 'Tidx'])([inputs[0]], attr) + return _impl + +def _reduce(op): def _impl(inputs, attr, params): axis = params.pop(inputs[1].name_hint).asnumpy() - # convert to tuple for preventing invalid parameter format error axis = tuple(axis) return AttrCvt( - op_name='sum', + op_name=op, extras={'axis': axis}, transforms={'keep_dims':'keepdims'}, ignores=['name', 'Tidx'])([inputs[0]], attr) @@ -704,24 +860,24 @@ def _impl(inputs, attr, params): def _gather(): "GatherV2, Gather" def _impl(inputs, attr, params): - - axis = 0 if len(inputs) > 2: - axis = params[inputs.pop(2).name_hint].asnumpy()[0] - new_input = [] - new_input.append(inputs.pop(0)) - new_input.append(inputs.pop(0)) + axis = _get_num_param(params, inputs.pop(2)) + else: + axis = 0 + new_input = inputs[0:2] return AttrCvt(op_name="take", extras={'axis': tvm.const(axis, 'int32')}, - ignores=['Tindices', 'Tparams', 'validate_indices', \ + ignores=['Tindices', 'Tparams', 'validate_indices', 'Taxis', '_class'])(new_input, attr) return _impl -def _infer_out_shapes(inputs, params): - """A method to get the output shape of an intermediate node in the relay graph.""" - out_type = ir_pass.infer_type(inputs) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - return out_shapes +def _gather_nd(): + """GatherNd""" + def _impl(inputs, attr, params): + return AttrCvt(op_name="gather_nd", + ignores=['Tindices', 'Tparams',\ + 'Taxis', '_class'])(inputs, attr) + 
return _impl def _stridedSlice(): def _impl(inputs, attr, params): @@ -730,9 +886,9 @@ def _impl(inputs, attr, params): Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/core/util/strided_slice_op.cc#L147-L368 """ - begin = params.pop(inputs[1].name_hint).asnumpy().tolist() - end = params.pop(inputs[2].name_hint).asnumpy().tolist() - stride = params.pop(inputs[3].name_hint).asnumpy().tolist() + begin = _get_list_param(params, inputs[1]) + end = _get_list_param(params, inputs[2]) + stride = _get_list_param(params, inputs[3]) begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) @@ -807,7 +963,7 @@ def _transform_mask(stride_dim, ellipsis_mask): if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_out_shapes(out, params)[0] + out_shape = _infer_shape(out, params) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -828,19 +984,14 @@ def _transform_mask(stride_dim, ellipsis_mask): def _pad(name): def _impl(inputs, attr, params): - padlist_key = inputs[1].name_hint - if padlist_key in params: - padlist = params.pop(padlist_key).asnumpy() - else: - raise tvm.error.OpAttributeRequired( - 'Attribute {} not found in operator Pad.'.format(padlist_key)) - paddings = tuple([tuple(l) for l in padlist]) + padlist = _get_param(params, inputs[1]) + paddings = tuple(tuple(l) for l in padlist) attr['pad_width'] = paddings attr['pad_value'] = 0 new_inputs = [inputs[0]] if name == 'PadV2': - constant_values = params.pop(inputs[2].name_hint).asnumpy() - attr['pad_value'] = constant_values[0] + constant_values = _get_num_param(params, inputs[2]) + attr['pad_value'] = constant_values return AttrCvt( op_name='pad', ignores=['Tpaddings'],)(new_inputs, attr) 
@@ -850,10 +1001,9 @@ def _transpose(): def _impl(inputs, attr, params): # If perm is not specified, axes is left empty, # otherwise its value is get from params - param_name = _get_name_hint(inputs[1]) - if param_name in params: - axes = tuple(params.get(param_name).asnumpy()) - else: + try: + axes = _get_list_param(params, inputs[1]) + except (IndexError, KeyError): axes = None return _op.transpose(inputs[0], axes=axes) return _impl @@ -863,9 +1013,16 @@ def _impl(inputs, attr, params): return AttrCvt(op_name="where")(inputs, attr) return _impl +def _clip_by_value(): + def _impl(inputs, attr, params): + a_min = params.pop(inputs[1].name_hint).asnumpy()[0] + a_max = params.pop(inputs[2].name_hint).asnumpy()[0] + return _op.clip(inputs[0], a_min=a_min, a_max=a_max) + return _impl + def _reverse_v2(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint).asnumpy()[0] + axis = _get_num_param(params, inputs[1]) return AttrCvt( op_name="reverse", ignores=['Tidx'], @@ -887,37 +1044,42 @@ def _impl(inputs, attr, params): def _range(): def _impl(inputs, attr, params): start = params.pop(inputs[0].name_hint).asnumpy()[0] - limit = params.pop(inputs[1].name_hint).asnumpy()[0] + limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ + if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] delta = params.pop(inputs[2].name_hint).asnumpy()[0] - - name = attr["_node_name"] - params[name] = tvm.nd.array([start, limit, delta]) - return [_expr.var(name, - shape=params[name].shape, - dtype='int32')] + dtype = attr['dtype'].name if 'dtype' in attr else "int32" + return AttrCvt( + op_name="arange", + ignores=['Tidx'], + extras={'start': start, + "stop": limit, + 'step': delta, + 'dtype': dtype})([], attr) return _impl def _elu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.0, attr['T'].name) - return alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.0, dtype) + return 
alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0]) return _impl def _selu(): def _impl(inputs, attr, params): - alpha = tvm.relay.const(-1.6732632423543772848170429916717, attr['T'].name) - gamma = tvm.relay.const(1.0507009873554804934193349852946, attr['T'].name) - return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, attr['T'].name) \ + dtype = attr['T'].name + alpha = tvm.relay.const(-1.6732632423543772848170429916717, dtype) + gamma = tvm.relay.const(1.0507009873554804934193349852946, dtype) + return gamma * (alpha * _op.nn.relu(tvm.relay.const(1, dtype) \ - _op.exp(inputs[0])) + _op.nn.relu(inputs[0])) return _impl def _mean(): def _impl(inputs, attr, params): - axis = params.pop(inputs[1].name_hint) + axis = _get_tuple_param(params, inputs[1]) return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'], transforms={'keep_dims': 'keepdims'}, - extras={'axis': tuple(axis.asnumpy())})([inputs[0]], attr) + extras={'axis': axis})([inputs[0]], attr) return _impl def _broadcast(name): @@ -943,8 +1105,7 @@ def _impl(inputs, attr, params): if has_size_vector: input_node_index = 0 input_axis_index = 2 - size_splits_input_name = _get_name_hint(inputs[1]) - size_splits = params[size_splits_input_name].asnumpy() + size_splits = _get_param(params, inputs[1]) section_beginnings = np.cumsum(size_splits)[:-1] indices_or_sections = tuple(section_beginnings) else: @@ -952,8 +1113,7 @@ def _impl(inputs, attr, params): input_axis_index = 0 indices_or_sections = attr['num_split'] input_node = inputs[input_node_index] - axis_input_name = _get_name_hint(inputs[input_axis_index]) - axis_input_value = params[axis_input_name].asnumpy()[0] + axis_input_value = _get_num_param(params, inputs[input_axis_index]) except (IndexError, KeyError): raise TypeError( \ "Unsupported argument for split: `axis` and `num_or_size_splits` " \ @@ -990,6 +1150,37 @@ def _impl(inputs, attr, params): transforms={'axis': ('axis', 1)})([inputs[0]], attr) return _impl 
+def _softplus(): + # op description: https://www.tensorflow.org/api_docs/python/tf/math/softplus + def _impl(inputs, attr, params): + exp_out = AttrCvt('exp')(inputs, attr) + inputs.append(tvm.relay.const(1, attr['T'].name)) + rh = tvm.relay.const(1, attr['T'].name) + add_out = _get_relay_op('add')(exp_out, rh) + return _get_relay_op('log')(add_out) + return _impl + +def _topk(): + def _impl(inputs, attr, params): + k = int(params.pop(inputs.pop(1).name_hint).asnumpy()) + if k < 1: + raise tvm.error.OpAttributeInvalid( + 'Attribute k must be positive in operator TopKV2') + if attr['sorted'] is False: + raise tvm.error.OpAttributeUnimplemented( + 'Attribute sorted=False is not supported in operator TopKV2') + return AttrCvt(op_name='topk', + ignores=['sorted'], + extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) + return _impl + +def _floordiv(): + def _impl(inputs, attr, params): + assert len(inputs) == 2 + div = AttrCvt('divide')(inputs, attr) + return _get_relay_op('floor')(div) + return _impl + def _logical(name): def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) @@ -999,8 +1190,8 @@ def _space_to_batch_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + paddings = _get_list_param(params, inputs[2]) N = len(input_shape) M = len(block_shape) batch = input_shape[0] @@ -1021,7 +1212,7 @@ def _impl(inputs, attr, params): axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] + permuted_reshaped_padded_shape = 
_infer_shape(permuted_reshaped_padded, params) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., @@ -1038,8 +1229,8 @@ def _batch_to_space_nd(): def _impl(inputs, attr, params): input_node = inputs[0] input_shape = attr['_input_shapes'][input_node] - block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() - crops = params.pop(inputs[2].name_hint).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1]) + crops = _get_list_param(params, inputs[2]) M = len(block_shape) batch = input_shape[0] # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: @@ -1064,7 +1255,7 @@ def _impl(inputs, attr, params): # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] - reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0] + reshaped_permuted_shape = _infer_shape(reshaped_permuted, params) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] @@ -1098,48 +1289,62 @@ def _impl(inputs, attr, params): # for 1 to N mapping(composed), use custom callable functions # for N to 1 mapping, currently not supported(?) 
_convert_map = { + 'Abs' : AttrCvt('abs'), 'Add' : _elemwise('add'), + 'All' : _reduce('all'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'AvgPool' : _pooling('avg_pool'), 'BatchNormWithGlobalNormalization' : _batch_norm(), 'BatchToSpaceND' : _batch_to_space_nd(), 'BiasAdd' : _bias_add(), + 'BroadcastTo' : _broadcast_to(), 'Cast' : _cast(), 'Ceil' : AttrCvt('ceil'), 'CheckNumerics' : _check_numerics(), + 'ClipByValue' : _clip_by_value(), 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), + 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), + 'DepthToSpace' : _depth_to_space(), 'Equal' : _broadcast('equal'), 'Elu' : _elu(), 'Exp' : AttrCvt('exp'), 'ExpandDims' : _expand_dims(), 'Fill' : _fill(), 'Floor' : AttrCvt('floor'), + 'FloorDiv' : _floordiv(), 'FusedBatchNorm' : _fused_batch_norm(), 'FusedBatchNormV2' : _fused_batch_norm(), 'Gather' : _gather(), + 'GatherNd' : _gather_nd(), 'GatherV2' : _gather(), 'Greater' : _broadcast('greater'), 'GreaterEqual' : _broadcast('greater_equal'), 'Identity' : _identity(), 'LeakyRelu' : AttrCvt('leaky_relu'), + 'LeftShift' : AttrCvt('left_shift'), 'Less' : _broadcast('less'), 'LessEqual' : _broadcast('less_equal'), 'Log' : AttrCvt('log'), 'LogicalAnd' : _logical('logical_and'), 'LogicalOr' : _logical('logical_or'), 'LogicalNot' : _logical('logical_not'), + 'LogSoftmax' : AttrCvt('log_softmax'), 'LRN' : _lrn(), 'MatMul' : _matmul(), + 'Max' : _reduce('max'), 'MaxPool' : _pooling('max_pool'), 'Maximum' : _elemwise('maximum'), 'Mean' : _mean(), + 'Min' : _reduce('min'), 'Minimum' : _elemwise('minimum'), + 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), + 'Neg' : AttrCvt('negative'), 'NotEqual' : _broadcast('not_equal'), 'Pack' : _pack(), 'Pad' : _pad('Pad'), @@ -1148,12 +1353,15 @@ def _impl(inputs, attr, params): 'Prod' : _prod(), 'Range' : _range(), 'Rank' : _rank(), - 'RealDiv' : _elemwise('div'), 
+ 'RealDiv' : _elemwise('divide'), 'Relu' : AttrCvt('relu'), 'Relu6' : _relu6(), 'Reshape' : _reshape(), 'ResizeBilinear' : _resize_bilinear(), + 'ResizeBicubic' : _resize_bilinear(), + 'ResizeNearestNeighbor' : _resize_nearest_neighbor(), 'ReverseV2' : _reverse_v2(), + 'RightShift' : AttrCvt('right_shift'), 'Round' : AttrCvt('round'), 'Rsqrt' : _rsqrt(), 'Select' : _where(), @@ -1163,9 +1371,11 @@ def _impl(inputs, attr, params): 'Sign' : AttrCvt('sign'), 'Slice' : _slice(), 'Softmax' : _softmax(), + 'Softplus' : _softplus(), 'SpaceToBatchND' : _space_to_batch_nd(), 'Split' : _split(False), 'SplitV' : _split(True), + 'Sqrt' : AttrCvt('sqrt'), 'Square' : _square(), 'Squeeze' : _squeeze(), 'StridedSlice' : _stridedSlice(), @@ -1173,8 +1383,11 @@ def _impl(inputs, attr, params): 'Sum' : _sum(), 'Tanh' : AttrCvt('tanh'), 'Tile' : _tile(), + 'TopKV2' : _topk(), 'Transpose' : _transpose(), + 'TruncateMod' : _elemwise('mod'), 'Unpack' : _unpack(), + 'ZerosLike' : AttrCvt('zeros_like'), } @@ -1458,7 +1671,7 @@ def _in_while_loop(control_flow_node_map, op_name): Parameters ---------- control_flow_node_map : Dict[str, Set[str]] - A dictionay contains the unqiue control flow execution frame name to + A dictionay contains the unique control flow execution frame name to a set of primitive operators mapping. op_name : str @@ -1510,7 +1723,7 @@ def f2(): return tf.add(4, 23) r = tf.cond(tf.less(i, j), f1, f2) - This condition statement should be coverted into Relay in the following + This condition statement should be converted into Relay in the following form: .. code-block:: python @@ -1618,7 +1831,7 @@ def __init__(self): self._loop = None def _while_loop(self): - """An internal API to create a Relay recurisve call for a matched TF + """An internal API to create a Relay recursive call for a matched TF `while_loop` construct. 
""" wl = tvm.relay.var('while_loop') @@ -1676,9 +1889,10 @@ def __init__(self): self._input_shapes = {} self._loops = {} self._branches = {} + self._mod = _module.Module({}) def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): - """Construct relay nodes from tensorflow graph definition - GraphDef. + """Construct relay nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to Relay. Some of the assumptions listed below. @@ -1687,7 +1901,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): -> All Const nodes are params. -> Last node is assumed as graph output. -> _output_shapes : Graph should be frozen with add_shapes=True. - Or user can pass input shape dictionaly optionally. + Or user can pass input shape dictionary optionally. -> DecodeJpeg, ResizeBilinear: These are dummy operators. Hence user should handle preprocessing outside. -> CheckNumerics: No implementation as of now for this. @@ -1704,10 +1918,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): shape : Dictionary of input dimensions (Optional) Graph level input shape dictionary. + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- - sym : relay.op - The returned relay operator + mod : tvm.relay.Module + The module that optimizations will be performed on. + params : dict A dict of name: tvm.nd.array pairs, used as pretrained weights """ @@ -1728,7 +1946,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): for node in graph.node: node_name_prefix = node.name.rsplit('/', 1)[0] control_flow_node_map[node_name_prefix].add(node.op) - if node.op == 'Placeholder': + if node.op == 'Placeholder' or node.op == 'PlaceholderWithDefault': # Give priority to user argument. 
if shape and node.name in shape: self._input_shapes[node.name] = list(shape[node.name]) @@ -1788,7 +2006,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): attr = self._parse_attr(node.attr) - elif node.op != "Placeholder": + elif node.op != "Placeholder" and node.op != 'PlaceholderWithDefault': # Pass the parsed shapes instead attr["_output_shapes"] = output_shapes = self._output_shapes[node.name] @@ -1858,23 +2076,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): # Infer shapes even without specifying "add_shapes=True" if output_shapes == [None]: - out_shapes = [] - for node_item in self._nodes[node.name]: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]] self._output_shapes[node.name] = out_shapes if self._output_shapes[node.name] and shape and node.name in shape: assert self._output_shapes[node.name] == list(shape[node.name]) - # Infer shapes if passed explicitely + # Infer shapes if passed explicitly node_output = self._nodes[node.name] if shape and (not self._output_shapes[node.name][0] or -1 in self._output_shapes[node.name][0]): - out_shapes = [] - for node_item in node_output: - out_type = ir_pass.infer_type(node_item) - out_shapes.append(get_const_tuple(out_type.checked_type.shape)) + out_shapes = [_infer_shape(node_item) for node_item in node_output] self._output_shapes[node.name] = out_shapes out = [] @@ -1902,8 +2114,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out = out[0] if len(out) == 1 else _expr.Tuple(out) func = _expr.Function(ir_pass.free_vars(out), out) - - return func, self._params + self._mod[self._mod.entry_func] = func + return self._mod, self._params def _parse_import_prerequisites(self, graph): """ Calculate the named preconditions from TensorFlow `graph`. 
@@ -1913,7 +2125,7 @@ def _parse_import_prerequisites(self, graph): """ missing_operators = set() for node in graph.node: - if node.op == "Placeholder": + if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass @@ -2133,7 +2345,7 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. - The converter must specify conversions explicity for incompatible name, and + The converter must specify conversions explicitly for incompatible name, and apply handlers to operator attributes. Parameters @@ -2173,7 +2385,7 @@ def _convert_operator(self, op_name, inputs, attrs, def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): - """ Load tensorflow graph which is a python tensorflow graph object into relay. + """Load tensorflow graph which is a python tensorflow graph object into relay. The companion parameters will be handled automatically. Parameters @@ -2181,14 +2393,23 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): graph : GraphDef object Tensorflow GraphDef + layout : target layout to be used (Optional) + NCHW only supported now to enable NHWC models on GPU. + + shape : Dictionary of input dimensions (Optional) + Graph level input shape dictionary. + + outputs : List of output tensor names (Optional) + if not specified then the last node is assumed as graph output. + Returns ------- - sym : relay.op - Compatible relay operator + mod : tvm.relay.Module + The module that optimizations will be performed on. 
params : dict of str to tvm.ndarray Dict of converted parameters stored in tvm.ndarray format """ g = GraphProto() - sym, params = g.from_tensorflow(graph, layout, shape, outputs) - return sym, params + mod, params = g.from_tensorflow(graph, layout, shape, outputs) + return mod, params diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py index 9cb7eabf0ea5..8105ef03aeca 100644 --- a/python/tvm/relay/frontend/tensorflow_parser.py +++ b/python/tvm/relay/frontend/tensorflow_parser.py @@ -18,7 +18,6 @@ from __future__ import absolute_import as _abs from __future__ import print_function import os -from tensorflow.core.framework import graph_pb2 from tvm.contrib import util @@ -35,12 +34,12 @@ class TFParser(object): -------- .. code-block:: python - parser = TfParser(model_dir) - graph = parser.parse() - # graph is related graphdef of the model + parser = TFParser(model_dir) + graphdef = parser.parse() """ def __init__(self, model_dir): + from tensorflow.core.framework import graph_pb2 self._tmp_dir = util.tempdir() self._model_dir = model_dir self._graph = graph_pb2.GraphDef() @@ -96,6 +95,7 @@ def _load_saved_model(self): from tensorflow.python.tools import freeze_graph from tensorflow.python.framework import ops from tensorflow.python.framework import graph_util + from tensorflow.core.framework import graph_pb2 except ImportError: raise ImportError( "InputConfiguration: Unable to import tensorflow which is " diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ff62d89412e9..fe163871fa60 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -22,6 +22,7 @@ import tvm from .. import ir_pass from .. import expr as _expr +from .. import module as _module from .. import op as _op from ... 
import nd as _nd from .common import ExprTable @@ -59,12 +60,26 @@ def __init__(self, model, subgraph, exp_tab): 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'RESHAPE': self.convert_reshape, + 'RESIZE_BILINEAR': self.convert_resize_bilinear, + 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'SOFTMAX': self.convert_softmax, 'SQUEEZE': self.convert_squeeze, 'MAX_POOL_2D': self.convert_max_pool2d, 'CONCATENATION': self.convert_concatenation, 'ADD': self.convert_add, + 'SUB': self.convert_sub, + 'MUL': self.convert_mul, + 'DIV': self.convert_div, + 'POW': self.convert_pow, + 'MAXIMUM': self.convert_maximum, + 'MINIMUM': self.convert_minimum, + 'REDUCE_MIN': self._convert_reduce_min, + 'REDUCE_MAX': self._convert_reduce_max, + 'MEAN': self._convert_reduce_mean, + 'REDUCE_PROD': self._convert_reduce_prod, 'FULLY_CONNECTED': self.convert_fully_connected, + 'PAD': self.convert_pad, + 'LOGISTIC': self.convert_logistic, } def check_unsupported_ops(self): @@ -112,7 +127,7 @@ def get_op_code_str(self, op): op_code_str = self.builtin_op_code[op_code_id] if op_code_id == BuiltinOperator.CUSTOM: # Custom operator - raise NotImplementedError("Not Support Custom Operator Now") + raise NotImplementedError("Custom operators are currently not supported") return op_code_str def get_input_tensors(self, op): @@ -155,7 +170,7 @@ def get_tensor_value(self, tensor_wrapper): if tensor_wrapper.tensor.Type() == TensorType.INT32: return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( tensor_wrapper.tensor.ShapeAsNumpy()) - raise NotImplementedError("Not support tensor type {}" + raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) def get_tensor_type_str(self, tensor_type): @@ -171,7 +186,8 @@ def get_tensor_type_str(self, tensor_type): return "float32" if tensor_type == TensorType.INT32: return "int32" - raise NotImplementedError("Not 
support tensor type {}".format(str(tensor_type))) + raise NotImplementedError("Tensor type {} is currently not supported" + .format(str(tensor_type))) def convert_conv2d(self, op): """Convert TFLite conv2d""" @@ -209,44 +225,79 @@ def convert_reshape(self, op): reshape_options = ReshapeOptions() reshape_options.Init(op_options.Bytes, op_options.Pos) target_shape = reshape_options.NewShapeAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) + out = _op.reshape(in_expr, newshape=tuple(target_shape)) - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Reshape is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) + return out - out = _op.reshape(in_expr, newshape=tuple(target_shape)) + def _convert_resize(self, method, op): + """Generic method to Convert TFLite RESIZE operators""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.Operator import Operator + from tflite.ResizeBilinearOptions import ResizeBilinearOptions + # ResizeNearestNeighborOptions was added in tflite v1.13 + tflite_ver = 1120 + if 'ResizeNearestNeighborOptions' in dir(BuiltinOptions): + from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions + tflite_ver = 1130 + except ImportError: + raise ImportError("The tflite package must be installed") - # The rule is channel first. 
- # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if len(target_shape) == 1 or len(target_shape) == 2: - pass - elif len(target_shape) == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif len(target_shape) == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - raise tvm.error.OpAttributeInvalid( - 'Length of target shape must be between 1 and 5 for operator Reshape.') + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + # images, 4-D Tensor with shape NHWC. + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # size - 1-D int32 Tensor of 2 elements: new_height, new_width + target_size = tuple(self.get_tensor_value(input_tensors[1])) + + # Options - align_corners (bool) + resize_options = None + align_corners = False + if method == "BILINEAR": + assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions + resize_options = ResizeBilinearOptions() + elif tflite_ver >= 1130: + assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions + resize_options = ResizeNearestNeighborOptions() + + if resize_options is not None: + op_options = op.BuiltinOptions() + resize_options.Init(op_options.Bytes, op_options.Pos) + align_corners = resize_options.AlignCorners() + + # Use layout NHWC + out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners) + return out + + def convert_resize_bilinear(self, op): + """Convert TFLite RESIZE_BILINEAR""" + return self._convert_resize("BILINEAR", op) + + def convert_resize_nearest_neighbor(self, op): + """Convert TFLite RESIZE_NEAREST_NEIGHBOR""" + return self._convert_resize("NEAREST_NEIGHBOR", op) + + def convert_logistic(self, op): + """Convert TFLite LOGISTIC""" + try: + from tflite.Operator import Operator + except ImportError: 
+ raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + out = _op.sigmoid(in_expr) return out def convert_softmax(self, op): @@ -269,7 +320,7 @@ def convert_softmax(self, op): return out def convert_concatenation(self, op): - """ convert TFLite concatenation""" + """Convert TFLite concatenation""" try: from tflite.Operator import Operator from tflite.ConcatenationOptions import ConcatenationOptions @@ -292,15 +343,6 @@ def convert_concatenation(self, op): concatenation_options.Init(op_options.Bytes, op_options.Pos) concatenation_axis = concatenation_options.Axis() fused_activation_fn = concatenation_options.FusedActivationFunction() - input_shape_length = len(input_tensors[0].tensor.ShapeAsNumpy()) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length <= 4: - axis_convert_map = [0] + list(range(2, input_shape_length)) + [1] - concatenation_axis = axis_convert_map[concatenation_axis] - else: - raise NotImplementedError("Not support input shape length {} of concatenatio : " - .format(str(input_shape_length))) # with axis in N H W C out = _op.concatenate(in_exprs, axis=concatenation_axis) @@ -310,10 +352,16 @@ def convert_concatenation(self, op): out = self.convert_fused_activation_function(out, fused_activation_fn) return out - def convert_add(self, op): - """Convert TFLite add""" + def _convert_elemwise(self, relay_op, op): + """Generic method to Convert TFLite elemwise""" try: from tflite.Operator import Operator + from tflite.AddOptions import AddOptions + from tflite.SubOptions import SubOptions + from tflite.MulOptions import MulOptions + from tflite.DivOptions import DivOptions + from tflite.BuiltinOptions import BuiltinOptions + from tflite.ActivationFunctionType import ActivationFunctionType except 
ImportError: raise ImportError("The tflite package must be installed") @@ -326,33 +374,105 @@ def convert_add(self, op): rhs_tensor = input_tensors[1] if self.has_expr(rhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses ADD operators + # In most cases, we can assume that TOCO fuses elemwise operators # with constants - it means both will be tensors. rhs_expr = self.get_expr(rhs_tensor.tensor_idx) else: - # However, in some corner cases, the ADD operator is not fused, + # However, in some corner cases, the elemwise operator is not fused, # we can receive as constant. rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), dtype=rhs_type_str) + out = relay_op(lhs_expr, rhs_expr) + + # Options (fused_activation_function) + options = None + if op.BuiltinOptionsType() == BuiltinOptions.AddOptions: + options = AddOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.SubOptions: + options = SubOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.MulOptions: + options = MulOptions() + elif op.BuiltinOptionsType() == BuiltinOptions.DivOptions: + options = DivOptions() + + if options is not None: + op_options = op.BuiltinOptions() + options.Init(op_options.Bytes, op_options.Pos) + fused_activation_fn = options.FusedActivationFunction() + # if we have activation fn + if fused_activation_fn != ActivationFunctionType.NONE: + out = self.convert_fused_activation_function(out, fused_activation_fn) - # In this case, we have to be careful about formatting. - input_shape_length = len(rhs_tensor.tensor.ShapeAsNumpy()) - if input_shape_length in (1, 2): - pass - elif input_shape_length == 3: - # N H*W C to N C H*W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # N H W C to N C H W - rhs_expr = _op.transpose(rhs_expr, axes=(0, 3, 1, 2)) - else: - msg = 'Input shape length {} for operator ADD is not valid.' 
- raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) + return out + + def convert_add(self, op): + """Convert TFLite ADD""" + return self._convert_elemwise(_op.add, op) + + def convert_sub(self, op): + """Convert TFLite SUB""" + return self._convert_elemwise(_op.subtract, op) + + def convert_mul(self, op): + """Convert TFLite MUL""" + return self._convert_elemwise(_op.multiply, op) + + def convert_div(self, op): + """Convert TFLite DIV""" + return self._convert_elemwise(_op.divide, op) - out = _op.add(lhs_expr, rhs_expr) + def convert_pow(self, op): + return self._convert_elemwise(_op.power, op) + + def convert_maximum(self, op): + return self._convert_elemwise(_op.maximum, op) + + def convert_minimum(self, op): + return self._convert_elemwise(_op.minimum, op) + + def _convert_reduce(self, relay_op, op): + """Generic method to Convert TFLite MEAN operators""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.Operator import Operator + from tflite.ReducerOptions import ReducerOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + # input_tensor + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # axis + axis = tuple(self.get_tensor_value(input_tensors[1])) + + # Options - keep_dims (bool) + assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions + reduce_options = ReducerOptions() + op_options = op.BuiltinOptions() + reduce_options.Init(op_options.Bytes, op_options.Pos) + keep_dims = reduce_options.KeepDims() + + out = relay_op(in_expr, axis, keep_dims) return out + def _convert_reduce_min(self, op): + return self._convert_reduce(_op.reduce.min, op) + + def _convert_reduce_max(self, op): + return self._convert_reduce(_op.reduce.max, op) + + def _convert_reduce_mean(self, op): + return 
self._convert_reduce(_op.reduce.mean, op) + + def _convert_reduce_prod(self, op): + return self._convert_reduce(_op.reduce.prod, op) + def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: @@ -440,46 +560,10 @@ def convert_squeeze(self, op): squeeze_options = SqueezeOptions() squeeze_options.Init(op_options.Bytes, op_options.Pos) squeeze_axis = squeeze_options.SqueezeDimsAsNumpy() - input_shape_length = len(input_tensor.tensor.ShapeAsNumpy()) - output_shape_length = len(output_tensors[0].tensor.ShapeAsNumpy()) in_expr = self.get_expr(input_tensor_idx) - - # TFLite is N H W C, our layout is N C H W - if input_shape_length in (1, 2): - # The rule is channel first (after N but before H, W). - # length of 1 means N*H*W*C, do nothing. - # length of 2 means N*H*W, C, do nothing. - pass - elif input_shape_length == 3: - # convert N C H*W to N H*W C - in_expr = _op.transpose(in_expr, axes=(0, 2, 1)) - elif input_shape_length == 4: - # convert input to N H W C, then reshape to target shape, - # finally convert back if necessary - in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1)) - else: - msg = 'Input shape length {} for operator Squeeze is not valid.' - raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length)) - out = _op.squeeze(in_expr, axis=tuple(squeeze_axis)) - # The rule is channel first. - # 1: N*H*W*C - # 2: N*H*W, C - # 3: N H W C, reshape to N H*W C, transpose to N C H*W - # 4: N H W C, transpose to N C H W - # add more if we need target shapes in future - if output_shape_length in (1, 2): - pass - elif output_shape_length == 3: - out = _op.transpose(out, axes=(0, 2, 1)) - elif output_shape_length == 4: - out = _op.transpose(out, axes=(0, 3, 1, 2)) - else: - msg = 'Output shape length {} for operator Squeeze is not valid.' 
- raise tvm.error.OpAttributeInvalid(msg.format(output_shape_length)) - return out def convert_fused_activation_function(self, in_expr, fused_activation_fn): @@ -535,8 +619,8 @@ def convert_conv(self, op, conv_type): conv_options = DepthwiseConv2DOptions() conv_options.Init(op_options.Bytes, op_options.Pos) depth_multiplier = conv_options.DepthMultiplier() - assert depth_multiplier == 1, "TF frontend have transformed it be 1 " \ - "no matter original value be set by 0.25, 0.5 or any else" + assert depth_multiplier == 1, "TF frontend transforms it to be 1 regardless of what " \ + "original value is set to 0.25, 0.5 or anything else" else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(conv_type)) @@ -562,13 +646,16 @@ def convert_conv(self, op, conv_type): params = {'kernel_size': [kernel_h, kernel_w], 'strides': [stride_h, stride_w], 'dilation': [dilation_h, dilation_w], - 'padding': [0, 0]} + 'padding': [0, 0], + 'data_layout': 'NHWC'} if is_depthwise_conv: params['channels'] = int(in_channels * multiplier) params['groups'] = int(in_channels) + params['kernel_layout'] = 'HWOI' else: params['channels'] = int(output_channels) + params['kernel_layout'] = 'HWIO' # weight tensor type should be UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() @@ -578,12 +665,9 @@ def convert_conv(self, op, conv_type): in_expr = self.get_expr(input_tensor_idx) weight_value = self.get_tensor_value(weight_tensor) - if is_depthwise_conv: - # TFLite is M KH KW IC, we require IC M KH KW - weight_value = weight_value.transpose((3, 0, 1, 2)) - else: - # TFLite is OC KH KW IC, we require OC IC KH kW - weight_value = weight_value.transpose((0, 3, 1, 2)) + # TFLite is OC/M KH KW IC, we require KH KW IC OC/M + # M means multiplier in depthwise convolution + weight_value = weight_value.transpose((1, 2, 3, 0)) weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) @@ -592,9 +676,10 @@ def 
convert_conv(self, op, conv_type): elif padding == Padding.SAME: pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h) pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) - in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (0, 0), + in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), (pad_top, pad_bottom), - (pad_left, pad_right))) + (pad_left, pad_right), + (0, 0))) else: raise tvm.error.OpAttributeUnimplemented( 'Padding format {} is not supported for operator Conv.'.format(padding)) @@ -610,7 +695,8 @@ def convert_conv(self, op, conv_type): bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type) bias_expr = self.exp_tab.new_const(self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str) - out = _op.nn.bias_add(out, bias_expr) + channel_axis = 3 + out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) # If we have fused activations if fused_activation_fn != ActivationFunctionType.NONE: @@ -648,7 +734,8 @@ def convert_pool2d(self, op, pool_type): params = {'pool_size': (filter_h, filter_w), 'strides': (stride_h, stride_w), - 'padding': [0, 0]} + 'padding': [0, 0], + 'layout': 'NHWC'} in_expr = self.get_expr(input_tensor_idx) @@ -677,6 +764,31 @@ def convert_pool2d(self, op, pool_type): return out + def convert_pad(self, op): + """Convert TFLite PAD""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + # TFLite only support CONSTANT mode and does not support constant_values parameter. 
+ # tensor + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # paddings + pad_list = self.get_tensor_value(input_tensors[1]) + # convert list of lists to tuple of tuples + paddings = tuple(tuple(l) for l in pad_list) + + # Use default pad_value 0 because TFLite does not support constant_values parameter + out = _op.nn.pad(in_expr, paddings) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -764,8 +876,8 @@ def from_tflite(model, shape_dict, dtype_dict): Returns ------- - func : tvm.relay.Function - Compatible relay Function + mod : tvm.relay.Module + The relay module for compilation. params : dict of str to tvm.NDArray The parameter dict to be used by relay @@ -803,4 +915,4 @@ def from_tflite(model, shape_dict, dtype_dict): outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) func = _expr.Function(ir_pass.free_vars(outputs), outputs) - return func, params + return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/grammar/Relay.g4 b/python/tvm/relay/grammar/Relay.g4 index 58546439e1ce..97b4ea24a8b2 100644 --- a/python/tvm/relay/grammar/Relay.g4 +++ b/python/tvm/relay/grammar/Relay.g4 @@ -19,7 +19,7 @@ grammar Relay; -SEMVER: 'v0.0.1' ; +SEMVER: 'v0.0.2' ; // Lexing // comments @@ -111,8 +111,8 @@ expr // | 'debug' # debug ; -func: 'fn' '(' argList ')' ('->' type_)? body ; -defn: 'def' ident '(' argList ')' ('->' type_)? body ; +func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ; +defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ; argList : varList @@ -132,15 +132,20 @@ attr: CNAME '=' expr ; // relations: 'where' relation (',' relation)* ; // relation: ident '(' (type_ (',' type_)*)? 
')' ; +typeParamSeq + : '[' ']' + | '[' ident (',' ident)* ']' + ; + type_ : '(' ')' # tupleType | '(' type_ ',' ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType - | identType # identTypeType + | typeIdent # typeIdentType | 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType // currently unused - // | identType '[' (type_ (',' type_)*)? ']' # callType - | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType + // | typeIdent '[' (type_ (',' type_)*)? ']' # callType + | 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | '_' # incompleteType | NAT # intType ; @@ -158,7 +163,7 @@ shape | NAT # intShape ; -identType: CNAME ; +typeIdent : CNAME ; // int8, int16, int32, int64 // uint8, uint16, uint32, uint64 // float16, float32, float64 diff --git a/python/tvm/relay/grammar/py2/.gitattributes b/python/tvm/relay/grammar/py2/.gitattributes new file mode 100644 index 000000000000..4adf65fa2f3c --- /dev/null +++ b/python/tvm/relay/grammar/py2/.gitattributes @@ -0,0 +1,3 @@ +Relay* binary +Relay* linguist-generated=true +Relay* linguist-detectable=false \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/.gitignore b/python/tvm/relay/grammar/py2/.gitignore deleted file mode 100644 index d677ff551940..000000000000 --- a/python/tvm/relay/grammar/py2/.gitignore +++ /dev/null @@ -1 +0,0 @@ -Relay* diff --git a/python/tvm/relay/grammar/py2/Relay.interp b/python/tvm/relay/grammar/py2/Relay.interp new file mode 100644 index 000000000000..c6893d096168 --- /dev/null +++ b/python/tvm/relay/grammar/py2/Relay.interp @@ -0,0 +1,109 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS 
+LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +opIdent +prog +expr +func +defn +argList +varList +var +attrList +attr +typeParamSeq +type_ +shapeSeq +shape +typeIdent +body +scalar +ident + + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 42, 332, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 3, 2, 3, 2, 3, 3, 3, 3, 7, 3, 43, 10, 3, 12, 3, 14, 3, 46, 11, 3, 3, 3, 5, 3, 49, 10, 3, 3, 3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 6, 4, 72, 10, 4, 13, 4, 14, 4, 73, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 82, 10, 4, 12, 4, 14, 4, 85, 11, 4, 5, 4, 87, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 100, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 110, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 128, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 150, 10, 4, 12, 4, 14, 4, 153, 11, 4, 5, 4, 155, 10, 4, 3, 4, 7, 4, 158, 10, 4, 12, 4, 14, 4, 161, 11, 4, 3, 5, 3, 5, 5, 5, 165, 10, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 5, 5, 172, 10, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 5, 6, 179, 10, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 5, 6, 186, 10, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 5, 7, 196, 10, 7, 3, 8, 3, 8, 3, 8, 7, 8, 201, 10, 8, 12, 8, 14, 8, 204, 11, 8, 5, 8, 206, 10, 8, 3, 9, 3, 9, 3, 9, 5, 9, 211, 10, 9, 3, 10, 3, 10, 3, 10, 7, 10, 216, 10, 10, 12, 10, 14, 10, 219, 11, 10, 5, 10, 221, 10, 10, 3, 11, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 7, 12, 233, 10, 12, 12, 12, 14, 
12, 236, 11, 12, 3, 12, 3, 12, 5, 12, 240, 10, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 6, 13, 253, 10, 13, 13, 13, 14, 13, 254, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 269, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 7, 13, 275, 10, 13, 12, 13, 14, 13, 278, 11, 13, 5, 13, 280, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 287, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 6, 14, 300, 10, 14, 13, 14, 14, 14, 301, 3, 14, 3, 14, 5, 14, 306, 10, 14, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 5, 15, 313, 10, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 5, 18, 324, 10, 18, 3, 19, 3, 19, 3, 19, 3, 19, 5, 19, 330, 10, 19, 3, 19, 2, 3, 6, 20, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 2, 6, 3, 2, 25, 26, 3, 2, 27, 28, 3, 2, 29, 32, 3, 2, 33, 34, 2, 373, 2, 38, 3, 2, 2, 2, 4, 40, 3, 2, 2, 2, 6, 127, 3, 2, 2, 2, 8, 162, 3, 2, 2, 2, 10, 175, 3, 2, 2, 2, 12, 195, 3, 2, 2, 2, 14, 205, 3, 2, 2, 2, 16, 207, 3, 2, 2, 2, 18, 220, 3, 2, 2, 2, 20, 222, 3, 2, 2, 2, 22, 239, 3, 2, 2, 2, 24, 286, 3, 2, 2, 2, 26, 305, 3, 2, 2, 2, 28, 312, 3, 2, 2, 2, 30, 314, 3, 2, 2, 2, 32, 316, 3, 2, 2, 2, 34, 323, 3, 2, 2, 2, 36, 329, 3, 2, 2, 2, 38, 39, 7, 42, 2, 2, 39, 3, 3, 2, 2, 2, 40, 48, 7, 21, 2, 2, 41, 43, 5, 10, 6, 2, 42, 41, 3, 2, 2, 2, 43, 46, 3, 2, 2, 2, 44, 42, 3, 2, 2, 2, 44, 45, 3, 2, 2, 2, 45, 49, 3, 2, 2, 2, 46, 44, 3, 2, 2, 2, 47, 49, 5, 6, 4, 2, 48, 44, 3, 2, 2, 2, 48, 47, 3, 2, 2, 2, 49, 50, 3, 2, 2, 2, 50, 51, 7, 2, 2, 3, 51, 5, 3, 2, 2, 2, 52, 53, 8, 4, 1, 2, 53, 54, 7, 3, 2, 2, 54, 55, 5, 6, 4, 2, 55, 56, 7, 4, 2, 2, 56, 128, 3, 2, 2, 2, 57, 58, 7, 28, 2, 2, 58, 128, 5, 6, 4, 19, 59, 128, 5, 8, 5, 2, 60, 61, 7, 3, 2, 2, 61, 128, 7, 4, 2, 2, 62, 63, 7, 3, 2, 2, 63, 64, 5, 6, 4, 2, 64, 65, 7, 5, 2, 2, 65, 66, 7, 4, 2, 2, 66, 128, 3, 2, 2, 2, 67, 68, 7, 3, 2, 2, 68, 71, 5, 6, 4, 2, 69, 70, 7, 5, 2, 2, 70, 72, 5, 6, 4, 
2, 71, 69, 3, 2, 2, 2, 72, 73, 3, 2, 2, 2, 73, 71, 3, 2, 2, 2, 73, 74, 3, 2, 2, 2, 74, 75, 3, 2, 2, 2, 75, 76, 7, 4, 2, 2, 76, 128, 3, 2, 2, 2, 77, 86, 7, 6, 2, 2, 78, 83, 5, 6, 4, 2, 79, 80, 7, 5, 2, 2, 80, 82, 5, 6, 4, 2, 81, 79, 3, 2, 2, 2, 82, 85, 3, 2, 2, 2, 83, 81, 3, 2, 2, 2, 83, 84, 3, 2, 2, 2, 84, 87, 3, 2, 2, 2, 85, 83, 3, 2, 2, 2, 86, 78, 3, 2, 2, 2, 86, 87, 3, 2, 2, 2, 87, 88, 3, 2, 2, 2, 88, 128, 7, 7, 2, 2, 89, 90, 7, 8, 2, 2, 90, 91, 7, 3, 2, 2, 91, 92, 5, 6, 4, 2, 92, 93, 7, 4, 2, 2, 93, 94, 5, 32, 17, 2, 94, 95, 7, 9, 2, 2, 95, 96, 5, 32, 17, 2, 96, 128, 3, 2, 2, 2, 97, 99, 7, 10, 2, 2, 98, 100, 7, 38, 2, 2, 99, 98, 3, 2, 2, 2, 99, 100, 3, 2, 2, 2, 100, 101, 3, 2, 2, 2, 101, 102, 5, 16, 9, 2, 102, 103, 7, 11, 2, 2, 103, 104, 5, 6, 4, 2, 104, 105, 7, 12, 2, 2, 105, 106, 5, 6, 4, 8, 106, 128, 3, 2, 2, 2, 107, 109, 7, 10, 2, 2, 108, 110, 7, 38, 2, 2, 109, 108, 3, 2, 2, 2, 109, 110, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 112, 5, 16, 9, 2, 112, 113, 7, 11, 2, 2, 113, 114, 7, 13, 2, 2, 114, 115, 5, 6, 4, 2, 115, 116, 7, 14, 2, 2, 116, 117, 7, 12, 2, 2, 117, 118, 5, 6, 4, 7, 118, 128, 3, 2, 2, 2, 119, 120, 5, 36, 19, 2, 120, 121, 7, 11, 2, 2, 121, 122, 5, 6, 4, 2, 122, 123, 7, 12, 2, 2, 123, 124, 5, 6, 4, 5, 124, 128, 3, 2, 2, 2, 125, 128, 5, 36, 19, 2, 126, 128, 5, 34, 18, 2, 127, 52, 3, 2, 2, 2, 127, 57, 3, 2, 2, 2, 127, 59, 3, 2, 2, 2, 127, 60, 3, 2, 2, 2, 127, 62, 3, 2, 2, 2, 127, 67, 3, 2, 2, 2, 127, 77, 3, 2, 2, 2, 127, 89, 3, 2, 2, 2, 127, 97, 3, 2, 2, 2, 127, 107, 3, 2, 2, 2, 127, 119, 3, 2, 2, 2, 127, 125, 3, 2, 2, 2, 127, 126, 3, 2, 2, 2, 128, 159, 3, 2, 2, 2, 129, 130, 12, 18, 2, 2, 130, 131, 9, 2, 2, 2, 131, 158, 5, 6, 4, 19, 132, 133, 12, 17, 2, 2, 133, 134, 9, 3, 2, 2, 134, 158, 5, 6, 4, 18, 135, 136, 12, 16, 2, 2, 136, 137, 9, 4, 2, 2, 137, 158, 5, 6, 4, 17, 138, 139, 12, 15, 2, 2, 139, 140, 9, 5, 2, 2, 140, 158, 5, 6, 4, 16, 141, 142, 12, 6, 2, 2, 142, 143, 7, 12, 2, 2, 143, 158, 5, 6, 4, 7, 144, 145, 12, 20, 2, 2, 145, 154, 7, 3, 2, 2, 
146, 151, 5, 6, 4, 2, 147, 148, 7, 5, 2, 2, 148, 150, 5, 6, 4, 2, 149, 147, 3, 2, 2, 2, 150, 153, 3, 2, 2, 2, 151, 149, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 155, 3, 2, 2, 2, 153, 151, 3, 2, 2, 2, 154, 146, 3, 2, 2, 2, 154, 155, 3, 2, 2, 2, 155, 156, 3, 2, 2, 2, 156, 158, 7, 4, 2, 2, 157, 129, 3, 2, 2, 2, 157, 132, 3, 2, 2, 2, 157, 135, 3, 2, 2, 2, 157, 138, 3, 2, 2, 2, 157, 141, 3, 2, 2, 2, 157, 144, 3, 2, 2, 2, 158, 161, 3, 2, 2, 2, 159, 157, 3, 2, 2, 2, 159, 160, 3, 2, 2, 2, 160, 7, 3, 2, 2, 2, 161, 159, 3, 2, 2, 2, 162, 164, 7, 15, 2, 2, 163, 165, 5, 22, 12, 2, 164, 163, 3, 2, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 3, 2, 2, 2, 166, 167, 7, 3, 2, 2, 167, 168, 5, 12, 7, 2, 168, 171, 7, 4, 2, 2, 169, 170, 7, 16, 2, 2, 170, 172, 5, 24, 13, 2, 171, 169, 3, 2, 2, 2, 171, 172, 3, 2, 2, 2, 172, 173, 3, 2, 2, 2, 173, 174, 5, 32, 17, 2, 174, 9, 3, 2, 2, 2, 175, 176, 7, 17, 2, 2, 176, 178, 5, 36, 19, 2, 177, 179, 5, 22, 12, 2, 178, 177, 3, 2, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 3, 2, 2, 2, 180, 181, 7, 3, 2, 2, 181, 182, 5, 12, 7, 2, 182, 185, 7, 4, 2, 2, 183, 184, 7, 16, 2, 2, 184, 186, 5, 24, 13, 2, 185, 183, 3, 2, 2, 2, 185, 186, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 188, 5, 32, 17, 2, 188, 11, 3, 2, 2, 2, 189, 196, 5, 14, 8, 2, 190, 196, 5, 18, 10, 2, 191, 192, 5, 14, 8, 2, 192, 193, 7, 5, 2, 2, 193, 194, 5, 18, 10, 2, 194, 196, 3, 2, 2, 2, 195, 189, 3, 2, 2, 2, 195, 190, 3, 2, 2, 2, 195, 191, 3, 2, 2, 2, 196, 13, 3, 2, 2, 2, 197, 202, 5, 16, 9, 2, 198, 199, 7, 5, 2, 2, 199, 201, 5, 16, 9, 2, 200, 198, 3, 2, 2, 2, 201, 204, 3, 2, 2, 2, 202, 200, 3, 2, 2, 2, 202, 203, 3, 2, 2, 2, 203, 206, 3, 2, 2, 2, 204, 202, 3, 2, 2, 2, 205, 197, 3, 2, 2, 2, 205, 206, 3, 2, 2, 2, 206, 15, 3, 2, 2, 2, 207, 210, 5, 36, 19, 2, 208, 209, 7, 18, 2, 2, 209, 211, 5, 24, 13, 2, 210, 208, 3, 2, 2, 2, 210, 211, 3, 2, 2, 2, 211, 17, 3, 2, 2, 2, 212, 217, 5, 20, 11, 2, 213, 214, 7, 5, 2, 2, 214, 216, 5, 20, 11, 2, 215, 213, 3, 2, 2, 2, 216, 219, 3, 2, 2, 2, 217, 215, 3, 2, 2, 2, 217, 218, 3, 
2, 2, 2, 218, 221, 3, 2, 2, 2, 219, 217, 3, 2, 2, 2, 220, 212, 3, 2, 2, 2, 220, 221, 3, 2, 2, 2, 221, 19, 3, 2, 2, 2, 222, 223, 7, 42, 2, 2, 223, 224, 7, 11, 2, 2, 224, 225, 5, 6, 4, 2, 225, 21, 3, 2, 2, 2, 226, 227, 7, 6, 2, 2, 227, 240, 7, 7, 2, 2, 228, 229, 7, 6, 2, 2, 229, 234, 5, 36, 19, 2, 230, 231, 7, 5, 2, 2, 231, 233, 5, 36, 19, 2, 232, 230, 3, 2, 2, 2, 233, 236, 3, 2, 2, 2, 234, 232, 3, 2, 2, 2, 234, 235, 3, 2, 2, 2, 235, 237, 3, 2, 2, 2, 236, 234, 3, 2, 2, 2, 237, 238, 7, 7, 2, 2, 238, 240, 3, 2, 2, 2, 239, 226, 3, 2, 2, 2, 239, 228, 3, 2, 2, 2, 240, 23, 3, 2, 2, 2, 241, 242, 7, 3, 2, 2, 242, 287, 7, 4, 2, 2, 243, 244, 7, 3, 2, 2, 244, 245, 5, 24, 13, 2, 245, 246, 7, 5, 2, 2, 246, 247, 7, 4, 2, 2, 247, 287, 3, 2, 2, 2, 248, 249, 7, 3, 2, 2, 249, 252, 5, 24, 13, 2, 250, 251, 7, 5, 2, 2, 251, 253, 5, 24, 13, 2, 252, 250, 3, 2, 2, 2, 253, 254, 3, 2, 2, 2, 254, 252, 3, 2, 2, 2, 254, 255, 3, 2, 2, 2, 255, 256, 3, 2, 2, 2, 256, 257, 7, 4, 2, 2, 257, 287, 3, 2, 2, 2, 258, 287, 5, 30, 16, 2, 259, 260, 7, 19, 2, 2, 260, 261, 7, 6, 2, 2, 261, 262, 5, 26, 14, 2, 262, 263, 7, 5, 2, 2, 263, 264, 5, 24, 13, 2, 264, 265, 7, 7, 2, 2, 265, 287, 3, 2, 2, 2, 266, 268, 7, 15, 2, 2, 267, 269, 5, 22, 12, 2, 268, 267, 3, 2, 2, 2, 268, 269, 3, 2, 2, 2, 269, 270, 3, 2, 2, 2, 270, 279, 7, 3, 2, 2, 271, 276, 5, 24, 13, 2, 272, 273, 7, 5, 2, 2, 273, 275, 5, 24, 13, 2, 274, 272, 3, 2, 2, 2, 275, 278, 3, 2, 2, 2, 276, 274, 3, 2, 2, 2, 276, 277, 3, 2, 2, 2, 277, 280, 3, 2, 2, 2, 278, 276, 3, 2, 2, 2, 279, 271, 3, 2, 2, 2, 279, 280, 3, 2, 2, 2, 280, 281, 3, 2, 2, 2, 281, 282, 7, 4, 2, 2, 282, 283, 7, 16, 2, 2, 283, 287, 5, 24, 13, 2, 284, 287, 7, 20, 2, 2, 285, 287, 7, 41, 2, 2, 286, 241, 3, 2, 2, 2, 286, 243, 3, 2, 2, 2, 286, 248, 3, 2, 2, 2, 286, 258, 3, 2, 2, 2, 286, 259, 3, 2, 2, 2, 286, 266, 3, 2, 2, 2, 286, 284, 3, 2, 2, 2, 286, 285, 3, 2, 2, 2, 287, 25, 3, 2, 2, 2, 288, 289, 7, 3, 2, 2, 289, 306, 7, 4, 2, 2, 290, 291, 7, 3, 2, 2, 291, 292, 5, 28, 15, 2, 292, 293, 7, 5, 2, 2, 
293, 294, 7, 4, 2, 2, 294, 306, 3, 2, 2, 2, 295, 296, 7, 3, 2, 2, 296, 299, 5, 28, 15, 2, 297, 298, 7, 5, 2, 2, 298, 300, 5, 28, 15, 2, 299, 297, 3, 2, 2, 2, 300, 301, 3, 2, 2, 2, 301, 299, 3, 2, 2, 2, 301, 302, 3, 2, 2, 2, 302, 303, 3, 2, 2, 2, 303, 304, 7, 4, 2, 2, 304, 306, 3, 2, 2, 2, 305, 288, 3, 2, 2, 2, 305, 290, 3, 2, 2, 2, 305, 295, 3, 2, 2, 2, 306, 27, 3, 2, 2, 2, 307, 308, 7, 3, 2, 2, 308, 309, 5, 28, 15, 2, 309, 310, 7, 4, 2, 2, 310, 313, 3, 2, 2, 2, 311, 313, 7, 41, 2, 2, 312, 307, 3, 2, 2, 2, 312, 311, 3, 2, 2, 2, 313, 29, 3, 2, 2, 2, 314, 315, 7, 42, 2, 2, 315, 31, 3, 2, 2, 2, 316, 317, 7, 13, 2, 2, 317, 318, 5, 6, 4, 2, 318, 319, 7, 14, 2, 2, 319, 33, 3, 2, 2, 2, 320, 324, 7, 40, 2, 2, 321, 324, 7, 41, 2, 2, 322, 324, 7, 39, 2, 2, 323, 320, 3, 2, 2, 2, 323, 321, 3, 2, 2, 2, 323, 322, 3, 2, 2, 2, 324, 35, 3, 2, 2, 2, 325, 330, 5, 2, 2, 2, 326, 330, 7, 35, 2, 2, 327, 330, 7, 36, 2, 2, 328, 330, 7, 37, 2, 2, 329, 325, 3, 2, 2, 2, 329, 326, 3, 2, 2, 2, 329, 327, 3, 2, 2, 2, 329, 328, 3, 2, 2, 2, 330, 37, 3, 2, 2, 2, 36, 44, 48, 73, 83, 86, 99, 109, 127, 151, 154, 157, 159, 164, 171, 178, 185, 195, 202, 205, 210, 217, 220, 234, 239, 254, 268, 276, 279, 286, 301, 305, 312, 323, 329] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/Relay.tokens b/python/tvm/relay/grammar/py2/Relay.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py2/Relay.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 
+'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py2/RelayLexer.interp b/python/tvm/relay/grammar/py2/RelayLexer.interp new file mode 100644 index 000000000000..092b3589ab70 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.interp @@ -0,0 +1,140 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +T__0 +T__1 +T__2 +T__3 +T__4 +T__5 +T__6 +T__7 +T__8 +T__9 +T__10 +T__11 +T__12 +T__13 +T__14 +T__15 +T__16 +T__17 +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +EXP +CNAME +LETTER +DIGIT + +channel names: +DEFAULT_TOKEN_CHANNEL +HIDDEN + +mode names: +DEFAULT_MODE + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 42, 267, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 
5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 6, 21, 149, 10, 21, 13, 21, 14, 21, 150, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 22, 7, 22, 159, 10, 22, 12, 22, 14, 22, 162, 11, 22, 3, 22, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 23, 3, 23, 7, 23, 172, 10, 23, 12, 23, 14, 23, 175, 11, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 34, 3, 34, 3, 34, 3, 35, 3, 35, 3, 35, 3, 36, 3, 36, 3, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 228, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 234, 10, 39, 3, 39, 3, 39, 3, 39, 5, 39, 239, 10, 39, 3, 40, 6, 40, 242, 10, 40, 13, 40, 14, 40, 243, 3, 41, 3, 41, 5, 41, 248, 10, 41, 3, 41, 3, 41, 3, 42, 3, 42, 5, 42, 254, 10, 42, 3, 42, 3, 42, 3, 42, 7, 42, 259, 10, 42, 12, 42, 14, 42, 262, 11, 42, 3, 43, 3, 43, 3, 44, 3, 44, 4, 160, 173, 2, 45, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 2, 83, 42, 85, 2, 87, 2, 3, 2, 7, 5, 2, 11, 12, 15, 15, 34, 34, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 67, 92, 99, 124, 3, 2, 50, 59, 2, 275, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 
2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 3, 89, 3, 2, 2, 2, 5, 91, 3, 2, 2, 2, 7, 93, 3, 2, 2, 2, 9, 95, 3, 2, 2, 2, 11, 97, 3, 2, 2, 2, 13, 99, 3, 2, 2, 2, 15, 102, 3, 2, 2, 2, 17, 107, 3, 2, 2, 2, 19, 111, 3, 2, 2, 2, 21, 113, 3, 2, 2, 2, 23, 115, 3, 2, 2, 2, 25, 117, 3, 2, 2, 2, 27, 119, 3, 2, 2, 2, 29, 122, 3, 2, 2, 2, 31, 125, 3, 2, 2, 2, 33, 129, 3, 2, 2, 2, 35, 131, 3, 2, 2, 2, 37, 138, 3, 2, 2, 2, 39, 140, 3, 2, 2, 2, 41, 148, 3, 2, 2, 2, 43, 154, 3, 2, 2, 2, 45, 167, 3, 2, 2, 2, 47, 181, 3, 2, 2, 2, 49, 183, 3, 2, 2, 2, 51, 185, 3, 2, 2, 2, 53, 187, 3, 2, 2, 2, 55, 189, 3, 2, 2, 2, 57, 191, 3, 2, 2, 2, 59, 193, 3, 2, 2, 2, 61, 196, 3, 2, 2, 2, 63, 199, 3, 2, 2, 2, 65, 202, 3, 2, 2, 2, 67, 205, 3, 2, 2, 2, 69, 208, 3, 2, 2, 2, 71, 211, 3, 2, 2, 2, 73, 214, 3, 2, 2, 2, 75, 227, 3, 2, 2, 2, 77, 238, 3, 2, 2, 2, 79, 241, 3, 2, 2, 2, 81, 245, 3, 2, 2, 2, 83, 253, 3, 2, 2, 2, 85, 263, 3, 2, 2, 2, 87, 265, 3, 2, 2, 2, 89, 90, 7, 42, 2, 2, 90, 4, 3, 2, 2, 2, 91, 92, 7, 43, 2, 2, 92, 6, 3, 2, 2, 2, 93, 94, 7, 46, 2, 2, 94, 8, 3, 2, 2, 2, 95, 96, 7, 93, 2, 2, 96, 10, 3, 2, 2, 2, 97, 98, 7, 95, 2, 2, 98, 12, 3, 2, 2, 2, 99, 100, 7, 107, 2, 2, 100, 101, 7, 104, 2, 2, 101, 14, 3, 2, 2, 2, 102, 103, 7, 103, 2, 2, 103, 104, 7, 110, 2, 2, 104, 105, 7, 117, 2, 2, 105, 106, 7, 103, 2, 2, 106, 16, 3, 2, 2, 2, 107, 108, 7, 110, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 18, 3, 2, 2, 2, 111, 112, 7, 63, 2, 2, 112, 20, 3, 2, 2, 2, 113, 114, 7, 61, 2, 2, 114, 22, 3, 2, 2, 2, 115, 116, 7, 125, 2, 2, 
116, 24, 3, 2, 2, 2, 117, 118, 7, 127, 2, 2, 118, 26, 3, 2, 2, 2, 119, 120, 7, 104, 2, 2, 120, 121, 7, 112, 2, 2, 121, 28, 3, 2, 2, 2, 122, 123, 7, 47, 2, 2, 123, 124, 7, 64, 2, 2, 124, 30, 3, 2, 2, 2, 125, 126, 7, 102, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 104, 2, 2, 128, 32, 3, 2, 2, 2, 129, 130, 7, 60, 2, 2, 130, 34, 3, 2, 2, 2, 131, 132, 7, 86, 2, 2, 132, 133, 7, 103, 2, 2, 133, 134, 7, 112, 2, 2, 134, 135, 7, 117, 2, 2, 135, 136, 7, 113, 2, 2, 136, 137, 7, 116, 2, 2, 137, 36, 3, 2, 2, 2, 138, 139, 7, 97, 2, 2, 139, 38, 3, 2, 2, 2, 140, 141, 7, 120, 2, 2, 141, 142, 7, 50, 2, 2, 142, 143, 7, 48, 2, 2, 143, 144, 7, 50, 2, 2, 144, 145, 7, 48, 2, 2, 145, 146, 7, 52, 2, 2, 146, 40, 3, 2, 2, 2, 147, 149, 9, 2, 2, 2, 148, 147, 3, 2, 2, 2, 149, 150, 3, 2, 2, 2, 150, 148, 3, 2, 2, 2, 150, 151, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 153, 8, 21, 2, 2, 153, 42, 3, 2, 2, 2, 154, 155, 7, 49, 2, 2, 155, 156, 7, 49, 2, 2, 156, 160, 3, 2, 2, 2, 157, 159, 11, 2, 2, 2, 158, 157, 3, 2, 2, 2, 159, 162, 3, 2, 2, 2, 160, 161, 3, 2, 2, 2, 160, 158, 3, 2, 2, 2, 161, 163, 3, 2, 2, 2, 162, 160, 3, 2, 2, 2, 163, 164, 7, 12, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 8, 22, 2, 2, 166, 44, 3, 2, 2, 2, 167, 168, 7, 49, 2, 2, 168, 169, 7, 44, 2, 2, 169, 173, 3, 2, 2, 2, 170, 172, 11, 2, 2, 2, 171, 170, 3, 2, 2, 2, 172, 175, 3, 2, 2, 2, 173, 174, 3, 2, 2, 2, 173, 171, 3, 2, 2, 2, 174, 176, 3, 2, 2, 2, 175, 173, 3, 2, 2, 2, 176, 177, 7, 44, 2, 2, 177, 178, 7, 49, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 8, 23, 2, 2, 180, 46, 3, 2, 2, 2, 181, 182, 7, 44, 2, 2, 182, 48, 3, 2, 2, 2, 183, 184, 7, 49, 2, 2, 184, 50, 3, 2, 2, 2, 185, 186, 7, 45, 2, 2, 186, 52, 3, 2, 2, 2, 187, 188, 7, 47, 2, 2, 188, 54, 3, 2, 2, 2, 189, 190, 7, 62, 2, 2, 190, 56, 3, 2, 2, 2, 191, 192, 7, 64, 2, 2, 192, 58, 3, 2, 2, 2, 193, 194, 7, 62, 2, 2, 194, 195, 7, 63, 2, 2, 195, 60, 3, 2, 2, 2, 196, 197, 7, 64, 2, 2, 197, 198, 7, 63, 2, 2, 198, 62, 3, 2, 2, 2, 199, 200, 7, 63, 2, 2, 200, 201, 7, 63, 2, 2, 201, 64, 3, 2, 2, 2, 202, 
203, 7, 35, 2, 2, 203, 204, 7, 63, 2, 2, 204, 66, 3, 2, 2, 2, 205, 206, 7, 66, 2, 2, 206, 207, 5, 83, 42, 2, 207, 68, 3, 2, 2, 2, 208, 209, 7, 39, 2, 2, 209, 210, 5, 83, 42, 2, 210, 70, 3, 2, 2, 2, 211, 212, 7, 39, 2, 2, 212, 213, 5, 79, 40, 2, 213, 72, 3, 2, 2, 2, 214, 215, 7, 111, 2, 2, 215, 216, 7, 119, 2, 2, 216, 217, 7, 118, 2, 2, 217, 74, 3, 2, 2, 2, 218, 219, 7, 86, 2, 2, 219, 220, 7, 116, 2, 2, 220, 221, 7, 119, 2, 2, 221, 228, 7, 103, 2, 2, 222, 223, 7, 72, 2, 2, 223, 224, 7, 99, 2, 2, 224, 225, 7, 110, 2, 2, 225, 226, 7, 117, 2, 2, 226, 228, 7, 103, 2, 2, 227, 218, 3, 2, 2, 2, 227, 222, 3, 2, 2, 2, 228, 76, 3, 2, 2, 2, 229, 230, 5, 79, 40, 2, 230, 231, 7, 48, 2, 2, 231, 233, 5, 79, 40, 2, 232, 234, 5, 81, 41, 2, 233, 232, 3, 2, 2, 2, 233, 234, 3, 2, 2, 2, 234, 239, 3, 2, 2, 2, 235, 236, 5, 79, 40, 2, 236, 237, 5, 81, 41, 2, 237, 239, 3, 2, 2, 2, 238, 229, 3, 2, 2, 2, 238, 235, 3, 2, 2, 2, 239, 78, 3, 2, 2, 2, 240, 242, 5, 87, 44, 2, 241, 240, 3, 2, 2, 2, 242, 243, 3, 2, 2, 2, 243, 241, 3, 2, 2, 2, 243, 244, 3, 2, 2, 2, 244, 80, 3, 2, 2, 2, 245, 247, 9, 3, 2, 2, 246, 248, 9, 4, 2, 2, 247, 246, 3, 2, 2, 2, 247, 248, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 250, 5, 79, 40, 2, 250, 82, 3, 2, 2, 2, 251, 254, 7, 97, 2, 2, 252, 254, 5, 85, 43, 2, 253, 251, 3, 2, 2, 2, 253, 252, 3, 2, 2, 2, 254, 260, 3, 2, 2, 2, 255, 259, 7, 97, 2, 2, 256, 259, 5, 85, 43, 2, 257, 259, 5, 87, 44, 2, 258, 255, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 258, 257, 3, 2, 2, 2, 259, 262, 3, 2, 2, 2, 260, 258, 3, 2, 2, 2, 260, 261, 3, 2, 2, 2, 261, 84, 3, 2, 2, 2, 262, 260, 3, 2, 2, 2, 263, 264, 9, 5, 2, 2, 264, 86, 3, 2, 2, 2, 265, 266, 9, 6, 2, 2, 266, 88, 3, 2, 2, 2, 14, 2, 150, 160, 173, 227, 233, 238, 243, 247, 253, 258, 260, 3, 8, 2, 2] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py2/RelayLexer.py b/python/tvm/relay/grammar/py2/RelayLexer.py new file mode 100644 index 000000000000..be87421c2da6 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.py @@ -0,0 
+1,209 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from __future__ import print_function +from antlr4 import * +from io import StringIO +import sys + + + +def serializedATN(): + with StringIO() as buf: + buf.write(u"\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2") + buf.write(u"*\u010b\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4") + buf.write(u"\7\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r") + buf.write(u"\t\r\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22") + buf.write(u"\4\23\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4") + buf.write(u"\30\t\30\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35") + buf.write(u"\t\35\4\36\t\36\4\37\t\37\4 \t \4!\t!\4\"\t\"\4#\t#\4") + buf.write(u"$\t$\4%\t%\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t") + buf.write(u",\3\2\3\2\3\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7") + buf.write(u"\3\b\3\b\3\b\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13") + buf.write(u"\3\f\3\f\3\r\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3") + buf.write(u"\20\3\20\3\20\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22") + buf.write(u"\3\22\3\23\3\23\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3") + buf.write(u"\25\6\25\u0095\n\25\r\25\16\25\u0096\3\25\3\25\3\26\3") + buf.write(u"\26\3\26\3\26\7\26\u009f\n\26\f\26\16\26\u00a2\13\26") + buf.write(u"\3\26\3\26\3\26\3\26\3\27\3\27\3\27\3\27\7\27\u00ac\n") + buf.write(u"\27\f\27\16\27\u00af\13\27\3\27\3\27\3\27\3\27\3\27\3") + buf.write(u"\30\3\30\3\31\3\31\3\32\3\32\3\33\3\33\3\34\3\34\3\35") + buf.write(u"\3\35\3\36\3\36\3\36\3\37\3\37\3\37\3 \3 \3 \3!\3!\3") + buf.write(u"!\3\"\3\"\3\"\3#\3#\3#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&") + buf.write(u"\3&\3&\3&\3&\3&\3&\5&\u00e4\n&\3\'\3\'\3\'\3\'\5\'\u00ea") + buf.write(u"\n\'\3\'\3\'\3\'\5\'\u00ef\n\'\3(\6(\u00f2\n(\r(\16(") + buf.write(u"\u00f3\3)\3)\5)\u00f8\n)\3)\3)\3*\3*\5*\u00fe\n*\3*\3") + buf.write(u"*\3*\7*\u0103\n*\f*\16*\u0106\13*\3+\3+\3,\3,\4\u00a0") + 
buf.write(u"\u00ad\2-\3\3\5\4\7\5\t\6\13\7\r\b\17\t\21\n\23\13\25") + buf.write(u"\f\27\r\31\16\33\17\35\20\37\21!\22#\23%\24\'\25)\26") + buf.write(u"+\27-\30/\31\61\32\63\33\65\34\67\359\36;\37= ?!A\"C") + buf.write(u"#E$G%I&K\'M(O)Q\2S*U\2W\2\3\2\7\5\2\13\f\17\17\"\"\4") + buf.write(u"\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0113\2\3\3\2\2\2\2") + buf.write(u"\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13\3\2\2\2\2\r\3") + buf.write(u"\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3\2\2\2\2\25\3") + buf.write(u"\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2\2\2\2\35\3") + buf.write(u"\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2%\3\2\2\2") + buf.write(u"\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2\2/\3\2") + buf.write(u"\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67\3\2") + buf.write(u"\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2A\3") + buf.write(u"\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3\2\2\2\2") + buf.write(u"K\3\2\2\2\2M\3\2\2\2\2O\3\2\2\2\2S\3\2\2\2\3Y\3\2\2\2") + buf.write(u"\5[\3\2\2\2\7]\3\2\2\2\t_\3\2\2\2\13a\3\2\2\2\rc\3\2") + buf.write(u"\2\2\17f\3\2\2\2\21k\3\2\2\2\23o\3\2\2\2\25q\3\2\2\2") + buf.write(u"\27s\3\2\2\2\31u\3\2\2\2\33w\3\2\2\2\35z\3\2\2\2\37}") + buf.write(u"\3\2\2\2!\u0081\3\2\2\2#\u0083\3\2\2\2%\u008a\3\2\2\2") + buf.write(u"\'\u008c\3\2\2\2)\u0094\3\2\2\2+\u009a\3\2\2\2-\u00a7") + buf.write(u"\3\2\2\2/\u00b5\3\2\2\2\61\u00b7\3\2\2\2\63\u00b9\3\2") + buf.write(u"\2\2\65\u00bb\3\2\2\2\67\u00bd\3\2\2\29\u00bf\3\2\2\2") + buf.write(u";\u00c1\3\2\2\2=\u00c4\3\2\2\2?\u00c7\3\2\2\2A\u00ca") + buf.write(u"\3\2\2\2C\u00cd\3\2\2\2E\u00d0\3\2\2\2G\u00d3\3\2\2\2") + buf.write(u"I\u00d6\3\2\2\2K\u00e3\3\2\2\2M\u00ee\3\2\2\2O\u00f1") + buf.write(u"\3\2\2\2Q\u00f5\3\2\2\2S\u00fd\3\2\2\2U\u0107\3\2\2\2") + buf.write(u"W\u0109\3\2\2\2YZ\7*\2\2Z\4\3\2\2\2[\\\7+\2\2\\\6\3\2") + buf.write(u"\2\2]^\7.\2\2^\b\3\2\2\2_`\7]\2\2`\n\3\2\2\2ab\7_\2\2") + buf.write(u"b\f\3\2\2\2cd\7k\2\2de\7h\2\2e\16\3\2\2\2fg\7g\2\2gh") + 
buf.write(u"\7n\2\2hi\7u\2\2ij\7g\2\2j\20\3\2\2\2kl\7n\2\2lm\7g\2") + buf.write(u"\2mn\7v\2\2n\22\3\2\2\2op\7?\2\2p\24\3\2\2\2qr\7=\2\2") + buf.write(u"r\26\3\2\2\2st\7}\2\2t\30\3\2\2\2uv\7\177\2\2v\32\3\2") + buf.write(u"\2\2wx\7h\2\2xy\7p\2\2y\34\3\2\2\2z{\7/\2\2{|\7@\2\2") + buf.write(u"|\36\3\2\2\2}~\7f\2\2~\177\7g\2\2\177\u0080\7h\2\2\u0080") + buf.write(u" \3\2\2\2\u0081\u0082\7<\2\2\u0082\"\3\2\2\2\u0083\u0084") + buf.write(u"\7V\2\2\u0084\u0085\7g\2\2\u0085\u0086\7p\2\2\u0086\u0087") + buf.write(u"\7u\2\2\u0087\u0088\7q\2\2\u0088\u0089\7t\2\2\u0089$") + buf.write(u"\3\2\2\2\u008a\u008b\7a\2\2\u008b&\3\2\2\2\u008c\u008d") + buf.write(u"\7x\2\2\u008d\u008e\7\62\2\2\u008e\u008f\7\60\2\2\u008f") + buf.write(u"\u0090\7\62\2\2\u0090\u0091\7\60\2\2\u0091\u0092\7\64") + buf.write(u"\2\2\u0092(\3\2\2\2\u0093\u0095\t\2\2\2\u0094\u0093\3") + buf.write(u"\2\2\2\u0095\u0096\3\2\2\2\u0096\u0094\3\2\2\2\u0096") + buf.write(u"\u0097\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u0099\b\25\2") + buf.write(u"\2\u0099*\3\2\2\2\u009a\u009b\7\61\2\2\u009b\u009c\7") + buf.write(u"\61\2\2\u009c\u00a0\3\2\2\2\u009d\u009f\13\2\2\2\u009e") + buf.write(u"\u009d\3\2\2\2\u009f\u00a2\3\2\2\2\u00a0\u00a1\3\2\2") + buf.write(u"\2\u00a0\u009e\3\2\2\2\u00a1\u00a3\3\2\2\2\u00a2\u00a0") + buf.write(u"\3\2\2\2\u00a3\u00a4\7\f\2\2\u00a4\u00a5\3\2\2\2\u00a5") + buf.write(u"\u00a6\b\26\2\2\u00a6,\3\2\2\2\u00a7\u00a8\7\61\2\2\u00a8") + buf.write(u"\u00a9\7,\2\2\u00a9\u00ad\3\2\2\2\u00aa\u00ac\13\2\2") + buf.write(u"\2\u00ab\u00aa\3\2\2\2\u00ac\u00af\3\2\2\2\u00ad\u00ae") + buf.write(u"\3\2\2\2\u00ad\u00ab\3\2\2\2\u00ae\u00b0\3\2\2\2\u00af") + buf.write(u"\u00ad\3\2\2\2\u00b0\u00b1\7,\2\2\u00b1\u00b2\7\61\2") + buf.write(u"\2\u00b2\u00b3\3\2\2\2\u00b3\u00b4\b\27\2\2\u00b4.\3") + buf.write(u"\2\2\2\u00b5\u00b6\7,\2\2\u00b6\60\3\2\2\2\u00b7\u00b8") + buf.write(u"\7\61\2\2\u00b8\62\3\2\2\2\u00b9\u00ba\7-\2\2\u00ba\64") + buf.write(u"\3\2\2\2\u00bb\u00bc\7/\2\2\u00bc\66\3\2\2\2\u00bd\u00be") + 
buf.write(u"\7>\2\2\u00be8\3\2\2\2\u00bf\u00c0\7@\2\2\u00c0:\3\2") + buf.write(u"\2\2\u00c1\u00c2\7>\2\2\u00c2\u00c3\7?\2\2\u00c3<\3\2") + buf.write(u"\2\2\u00c4\u00c5\7@\2\2\u00c5\u00c6\7?\2\2\u00c6>\3\2") + buf.write(u"\2\2\u00c7\u00c8\7?\2\2\u00c8\u00c9\7?\2\2\u00c9@\3\2") + buf.write(u"\2\2\u00ca\u00cb\7#\2\2\u00cb\u00cc\7?\2\2\u00ccB\3\2") + buf.write(u"\2\2\u00cd\u00ce\7B\2\2\u00ce\u00cf\5S*\2\u00cfD\3\2") + buf.write(u"\2\2\u00d0\u00d1\7\'\2\2\u00d1\u00d2\5S*\2\u00d2F\3\2") + buf.write(u"\2\2\u00d3\u00d4\7\'\2\2\u00d4\u00d5\5O(\2\u00d5H\3\2") + buf.write(u"\2\2\u00d6\u00d7\7o\2\2\u00d7\u00d8\7w\2\2\u00d8\u00d9") + buf.write(u"\7v\2\2\u00d9J\3\2\2\2\u00da\u00db\7V\2\2\u00db\u00dc") + buf.write(u"\7t\2\2\u00dc\u00dd\7w\2\2\u00dd\u00e4\7g\2\2\u00de\u00df") + buf.write(u"\7H\2\2\u00df\u00e0\7c\2\2\u00e0\u00e1\7n\2\2\u00e1\u00e2") + buf.write(u"\7u\2\2\u00e2\u00e4\7g\2\2\u00e3\u00da\3\2\2\2\u00e3") + buf.write(u"\u00de\3\2\2\2\u00e4L\3\2\2\2\u00e5\u00e6\5O(\2\u00e6") + buf.write(u"\u00e7\7\60\2\2\u00e7\u00e9\5O(\2\u00e8\u00ea\5Q)\2\u00e9") + buf.write(u"\u00e8\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea\u00ef\3\2\2") + buf.write(u"\2\u00eb\u00ec\5O(\2\u00ec\u00ed\5Q)\2\u00ed\u00ef\3") + buf.write(u"\2\2\2\u00ee\u00e5\3\2\2\2\u00ee\u00eb\3\2\2\2\u00ef") + buf.write(u"N\3\2\2\2\u00f0\u00f2\5W,\2\u00f1\u00f0\3\2\2\2\u00f2") + buf.write(u"\u00f3\3\2\2\2\u00f3\u00f1\3\2\2\2\u00f3\u00f4\3\2\2") + buf.write(u"\2\u00f4P\3\2\2\2\u00f5\u00f7\t\3\2\2\u00f6\u00f8\t\4") + buf.write(u"\2\2\u00f7\u00f6\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") + buf.write(u"\3\2\2\2\u00f9\u00fa\5O(\2\u00faR\3\2\2\2\u00fb\u00fe") + buf.write(u"\7a\2\2\u00fc\u00fe\5U+\2\u00fd\u00fb\3\2\2\2\u00fd\u00fc") + buf.write(u"\3\2\2\2\u00fe\u0104\3\2\2\2\u00ff\u0103\7a\2\2\u0100") + buf.write(u"\u0103\5U+\2\u0101\u0103\5W,\2\u0102\u00ff\3\2\2\2\u0102") + buf.write(u"\u0100\3\2\2\2\u0102\u0101\3\2\2\2\u0103\u0106\3\2\2") + buf.write(u"\2\u0104\u0102\3\2\2\2\u0104\u0105\3\2\2\2\u0105T\3\2") + 
buf.write(u"\2\2\u0106\u0104\3\2\2\2\u0107\u0108\t\5\2\2\u0108V\3") + buf.write(u"\2\2\2\u0109\u010a\t\6\2\2\u010aX\3\2\2\2\16\2\u0096") + buf.write(u"\u00a0\u00ad\u00e3\u00e9\u00ee\u00f3\u00f7\u00fd\u0102") + buf.write(u"\u0104\3\b\2\2") + return buf.getvalue() + + +class RelayLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + SEMVER = 19 + WS = 20 + LINE_COMMENT = 21 + COMMENT = 22 + MUL = 23 + DIV = 24 + ADD = 25 + SUB = 26 + LT = 27 + GT = 28 + LE = 29 + GE = 30 + EQ = 31 + NE = 32 + GLOBAL_VAR = 33 + LOCAL_VAR = 34 + GRAPH_VAR = 35 + MUT = 36 + BOOL_LIT = 37 + FLOAT = 38 + NAT = 39 + CNAME = 40 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ u"DEFAULT_MODE" ] + + literalNames = [ u"", + u"'('", u"')'", u"','", u"'['", u"']'", u"'if'", u"'else'", + u"'let'", u"'='", u"';'", u"'{'", u"'}'", u"'fn'", u"'->'", + u"'def'", u"':'", u"'Tensor'", u"'_'", u"'v0.0.2'", u"'*'", + u"'/'", u"'+'", u"'-'", u"'<'", u"'>'", u"'<='", u"'>='", u"'=='", + u"'!='", u"'mut'" ] + + symbolicNames = [ u"", + u"SEMVER", u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", u"DIV", + u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", u"NE", u"GLOBAL_VAR", + u"LOCAL_VAR", u"GRAPH_VAR", u"MUT", u"BOOL_LIT", u"FLOAT", u"NAT", + u"CNAME" ] + + ruleNames = [ u"T__0", u"T__1", u"T__2", u"T__3", u"T__4", u"T__5", + u"T__6", u"T__7", u"T__8", u"T__9", u"T__10", u"T__11", + u"T__12", u"T__13", u"T__14", u"T__15", u"T__16", u"T__17", + u"SEMVER", u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", + u"DIV", u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", + u"NE", u"GLOBAL_VAR", u"LOCAL_VAR", u"GRAPH_VAR", u"MUT", + u"BOOL_LIT", u"FLOAT", u"NAT", u"EXP", u"CNAME", 
u"LETTER", + u"DIGIT" ] + + grammarFileName = u"Relay.g4" + + def __init__(self, input=None, output=sys.stdout): + super(RelayLexer, self).__init__(input, output=output) + self.checkVersion("4.7.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/python/tvm/relay/grammar/py2/RelayLexer.tokens b/python/tvm/relay/grammar/py2/RelayLexer.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayLexer.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py2/RelayParser.py b/python/tvm/relay/grammar/py2/RelayParser.py new file mode 100644 index 000000000000..77f56bf0545a --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayParser.py @@ -0,0 +1,2311 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from __future__ import print_function +from antlr4 import * +from io import StringIO +import sys + + +def serializedATN(): + with StringIO() as buf: + buf.write(u"\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3") + buf.write(u"*\u014c\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t") + buf.write(u"\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") + buf.write(u"\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4") + 
buf.write(u"\23\t\23\3\2\3\2\3\3\3\3\7\3+\n\3\f\3\16\3.\13\3\3\3") + buf.write(u"\5\3\61\n\3\3\3\3\3\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\6\4H\n\4\r") + buf.write(u"\4\16\4I\3\4\3\4\3\4\3\4\3\4\3\4\7\4R\n\4\f\4\16\4U\13") + buf.write(u"\4\5\4W\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\5\4d\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4n\n\4") + buf.write(u"\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write(u"\3\4\3\4\3\4\5\4\u0080\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3") + buf.write(u"\4\7\4\u0096\n\4\f\4\16\4\u0099\13\4\5\4\u009b\n\4\3") + buf.write(u"\4\7\4\u009e\n\4\f\4\16\4\u00a1\13\4\3\5\3\5\5\5\u00a5") + buf.write(u"\n\5\3\5\3\5\3\5\3\5\3\5\5\5\u00ac\n\5\3\5\3\5\3\6\3") + buf.write(u"\6\3\6\5\6\u00b3\n\6\3\6\3\6\3\6\3\6\3\6\5\6\u00ba\n") + buf.write(u"\6\3\6\3\6\3\7\3\7\3\7\3\7\3\7\3\7\5\7\u00c4\n\7\3\b") + buf.write(u"\3\b\3\b\7\b\u00c9\n\b\f\b\16\b\u00cc\13\b\5\b\u00ce") + buf.write(u"\n\b\3\t\3\t\3\t\5\t\u00d3\n\t\3\n\3\n\3\n\7\n\u00d8") + buf.write(u"\n\n\f\n\16\n\u00db\13\n\5\n\u00dd\n\n\3\13\3\13\3\13") + buf.write(u"\3\13\3\f\3\f\3\f\3\f\3\f\3\f\7\f\u00e9\n\f\f\f\16\f") + buf.write(u"\u00ec\13\f\3\f\3\f\5\f\u00f0\n\f\3\r\3\r\3\r\3\r\3\r") + buf.write(u"\3\r\3\r\3\r\3\r\3\r\3\r\6\r\u00fd\n\r\r\r\16\r\u00fe") + buf.write(u"\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\5\r") + buf.write(u"\u010d\n\r\3\r\3\r\3\r\3\r\7\r\u0113\n\r\f\r\16\r\u0116") + buf.write(u"\13\r\5\r\u0118\n\r\3\r\3\r\3\r\3\r\3\r\5\r\u011f\n\r") + buf.write(u"\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3\16\3") + buf.write(u"\16\6\16\u012c\n\16\r\16\16\16\u012d\3\16\3\16\5\16\u0132") + buf.write(u"\n\16\3\17\3\17\3\17\3\17\3\17\5\17\u0139\n\17\3\20\3") + buf.write(u"\20\3\21\3\21\3\21\3\21\3\22\3\22\3\22\5\22\u0144\n\22") + buf.write(u"\3\23\3\23\3\23\3\23\5\23\u014a\n\23\3\23\2\3\6\24\2") + 
buf.write(u"\4\6\b\n\f\16\20\22\24\26\30\32\34\36 \"$\2\6\3\2\31") + buf.write(u"\32\3\2\33\34\3\2\35 \3\2!\"\2\u0175\2&\3\2\2\2\4(\3") + buf.write(u"\2\2\2\6\177\3\2\2\2\b\u00a2\3\2\2\2\n\u00af\3\2\2\2") + buf.write(u"\f\u00c3\3\2\2\2\16\u00cd\3\2\2\2\20\u00cf\3\2\2\2\22") + buf.write(u"\u00dc\3\2\2\2\24\u00de\3\2\2\2\26\u00ef\3\2\2\2\30\u011e") + buf.write(u"\3\2\2\2\32\u0131\3\2\2\2\34\u0138\3\2\2\2\36\u013a\3") + buf.write(u"\2\2\2 \u013c\3\2\2\2\"\u0143\3\2\2\2$\u0149\3\2\2\2") + buf.write(u"&\'\7*\2\2\'\3\3\2\2\2(\60\7\25\2\2)+\5\n\6\2*)\3\2\2") + buf.write(u"\2+.\3\2\2\2,*\3\2\2\2,-\3\2\2\2-\61\3\2\2\2.,\3\2\2") + buf.write(u"\2/\61\5\6\4\2\60,\3\2\2\2\60/\3\2\2\2\61\62\3\2\2\2") + buf.write(u"\62\63\7\2\2\3\63\5\3\2\2\2\64\65\b\4\1\2\65\66\7\3\2") + buf.write(u"\2\66\67\5\6\4\2\678\7\4\2\28\u0080\3\2\2\29:\7\34\2") + buf.write(u"\2:\u0080\5\6\4\23;\u0080\5\b\5\2<=\7\3\2\2=\u0080\7") + buf.write(u"\4\2\2>?\7\3\2\2?@\5\6\4\2@A\7\5\2\2AB\7\4\2\2B\u0080") + buf.write(u"\3\2\2\2CD\7\3\2\2DG\5\6\4\2EF\7\5\2\2FH\5\6\4\2GE\3") + buf.write(u"\2\2\2HI\3\2\2\2IG\3\2\2\2IJ\3\2\2\2JK\3\2\2\2KL\7\4") + buf.write(u"\2\2L\u0080\3\2\2\2MV\7\6\2\2NS\5\6\4\2OP\7\5\2\2PR\5") + buf.write(u"\6\4\2QO\3\2\2\2RU\3\2\2\2SQ\3\2\2\2ST\3\2\2\2TW\3\2") + buf.write(u"\2\2US\3\2\2\2VN\3\2\2\2VW\3\2\2\2WX\3\2\2\2X\u0080\7") + buf.write(u"\7\2\2YZ\7\b\2\2Z[\7\3\2\2[\\\5\6\4\2\\]\7\4\2\2]^\5") + buf.write(u" \21\2^_\7\t\2\2_`\5 \21\2`\u0080\3\2\2\2ac\7\n\2\2b") + buf.write(u"d\7&\2\2cb\3\2\2\2cd\3\2\2\2de\3\2\2\2ef\5\20\t\2fg\7") + buf.write(u"\13\2\2gh\5\6\4\2hi\7\f\2\2ij\5\6\4\bj\u0080\3\2\2\2") + buf.write(u"km\7\n\2\2ln\7&\2\2ml\3\2\2\2mn\3\2\2\2no\3\2\2\2op\5") + buf.write(u"\20\t\2pq\7\13\2\2qr\7\r\2\2rs\5\6\4\2st\7\16\2\2tu\7") + buf.write(u"\f\2\2uv\5\6\4\7v\u0080\3\2\2\2wx\5$\23\2xy\7\13\2\2") + buf.write(u"yz\5\6\4\2z{\7\f\2\2{|\5\6\4\5|\u0080\3\2\2\2}\u0080") + buf.write(u"\5$\23\2~\u0080\5\"\22\2\177\64\3\2\2\2\1779\3\2\2\2") + 
buf.write(u"\177;\3\2\2\2\177<\3\2\2\2\177>\3\2\2\2\177C\3\2\2\2") + buf.write(u"\177M\3\2\2\2\177Y\3\2\2\2\177a\3\2\2\2\177k\3\2\2\2") + buf.write(u"\177w\3\2\2\2\177}\3\2\2\2\177~\3\2\2\2\u0080\u009f\3") + buf.write(u"\2\2\2\u0081\u0082\f\22\2\2\u0082\u0083\t\2\2\2\u0083") + buf.write(u"\u009e\5\6\4\23\u0084\u0085\f\21\2\2\u0085\u0086\t\3") + buf.write(u"\2\2\u0086\u009e\5\6\4\22\u0087\u0088\f\20\2\2\u0088") + buf.write(u"\u0089\t\4\2\2\u0089\u009e\5\6\4\21\u008a\u008b\f\17") + buf.write(u"\2\2\u008b\u008c\t\5\2\2\u008c\u009e\5\6\4\20\u008d\u008e") + buf.write(u"\f\6\2\2\u008e\u008f\7\f\2\2\u008f\u009e\5\6\4\7\u0090") + buf.write(u"\u0091\f\24\2\2\u0091\u009a\7\3\2\2\u0092\u0097\5\6\4") + buf.write(u"\2\u0093\u0094\7\5\2\2\u0094\u0096\5\6\4\2\u0095\u0093") + buf.write(u"\3\2\2\2\u0096\u0099\3\2\2\2\u0097\u0095\3\2\2\2\u0097") + buf.write(u"\u0098\3\2\2\2\u0098\u009b\3\2\2\2\u0099\u0097\3\2\2") + buf.write(u"\2\u009a\u0092\3\2\2\2\u009a\u009b\3\2\2\2\u009b\u009c") + buf.write(u"\3\2\2\2\u009c\u009e\7\4\2\2\u009d\u0081\3\2\2\2\u009d") + buf.write(u"\u0084\3\2\2\2\u009d\u0087\3\2\2\2\u009d\u008a\3\2\2") + buf.write(u"\2\u009d\u008d\3\2\2\2\u009d\u0090\3\2\2\2\u009e\u00a1") + buf.write(u"\3\2\2\2\u009f\u009d\3\2\2\2\u009f\u00a0\3\2\2\2\u00a0") + buf.write(u"\7\3\2\2\2\u00a1\u009f\3\2\2\2\u00a2\u00a4\7\17\2\2\u00a3") + buf.write(u"\u00a5\5\26\f\2\u00a4\u00a3\3\2\2\2\u00a4\u00a5\3\2\2") + buf.write(u"\2\u00a5\u00a6\3\2\2\2\u00a6\u00a7\7\3\2\2\u00a7\u00a8") + buf.write(u"\5\f\7\2\u00a8\u00ab\7\4\2\2\u00a9\u00aa\7\20\2\2\u00aa") + buf.write(u"\u00ac\5\30\r\2\u00ab\u00a9\3\2\2\2\u00ab\u00ac\3\2\2") + buf.write(u"\2\u00ac\u00ad\3\2\2\2\u00ad\u00ae\5 \21\2\u00ae\t\3") + buf.write(u"\2\2\2\u00af\u00b0\7\21\2\2\u00b0\u00b2\5$\23\2\u00b1") + buf.write(u"\u00b3\5\26\f\2\u00b2\u00b1\3\2\2\2\u00b2\u00b3\3\2\2") + buf.write(u"\2\u00b3\u00b4\3\2\2\2\u00b4\u00b5\7\3\2\2\u00b5\u00b6") + buf.write(u"\5\f\7\2\u00b6\u00b9\7\4\2\2\u00b7\u00b8\7\20\2\2\u00b8") + 
buf.write(u"\u00ba\5\30\r\2\u00b9\u00b7\3\2\2\2\u00b9\u00ba\3\2\2") + buf.write(u"\2\u00ba\u00bb\3\2\2\2\u00bb\u00bc\5 \21\2\u00bc\13\3") + buf.write(u"\2\2\2\u00bd\u00c4\5\16\b\2\u00be\u00c4\5\22\n\2\u00bf") + buf.write(u"\u00c0\5\16\b\2\u00c0\u00c1\7\5\2\2\u00c1\u00c2\5\22") + buf.write(u"\n\2\u00c2\u00c4\3\2\2\2\u00c3\u00bd\3\2\2\2\u00c3\u00be") + buf.write(u"\3\2\2\2\u00c3\u00bf\3\2\2\2\u00c4\r\3\2\2\2\u00c5\u00ca") + buf.write(u"\5\20\t\2\u00c6\u00c7\7\5\2\2\u00c7\u00c9\5\20\t\2\u00c8") + buf.write(u"\u00c6\3\2\2\2\u00c9\u00cc\3\2\2\2\u00ca\u00c8\3\2\2") + buf.write(u"\2\u00ca\u00cb\3\2\2\2\u00cb\u00ce\3\2\2\2\u00cc\u00ca") + buf.write(u"\3\2\2\2\u00cd\u00c5\3\2\2\2\u00cd\u00ce\3\2\2\2\u00ce") + buf.write(u"\17\3\2\2\2\u00cf\u00d2\5$\23\2\u00d0\u00d1\7\22\2\2") + buf.write(u"\u00d1\u00d3\5\30\r\2\u00d2\u00d0\3\2\2\2\u00d2\u00d3") + buf.write(u"\3\2\2\2\u00d3\21\3\2\2\2\u00d4\u00d9\5\24\13\2\u00d5") + buf.write(u"\u00d6\7\5\2\2\u00d6\u00d8\5\24\13\2\u00d7\u00d5\3\2") + buf.write(u"\2\2\u00d8\u00db\3\2\2\2\u00d9\u00d7\3\2\2\2\u00d9\u00da") + buf.write(u"\3\2\2\2\u00da\u00dd\3\2\2\2\u00db\u00d9\3\2\2\2\u00dc") + buf.write(u"\u00d4\3\2\2\2\u00dc\u00dd\3\2\2\2\u00dd\23\3\2\2\2\u00de") + buf.write(u"\u00df\7*\2\2\u00df\u00e0\7\13\2\2\u00e0\u00e1\5\6\4") + buf.write(u"\2\u00e1\25\3\2\2\2\u00e2\u00e3\7\6\2\2\u00e3\u00f0\7") + buf.write(u"\7\2\2\u00e4\u00e5\7\6\2\2\u00e5\u00ea\5$\23\2\u00e6") + buf.write(u"\u00e7\7\5\2\2\u00e7\u00e9\5$\23\2\u00e8\u00e6\3\2\2") + buf.write(u"\2\u00e9\u00ec\3\2\2\2\u00ea\u00e8\3\2\2\2\u00ea\u00eb") + buf.write(u"\3\2\2\2\u00eb\u00ed\3\2\2\2\u00ec\u00ea\3\2\2\2\u00ed") + buf.write(u"\u00ee\7\7\2\2\u00ee\u00f0\3\2\2\2\u00ef\u00e2\3\2\2") + buf.write(u"\2\u00ef\u00e4\3\2\2\2\u00f0\27\3\2\2\2\u00f1\u00f2\7") + buf.write(u"\3\2\2\u00f2\u011f\7\4\2\2\u00f3\u00f4\7\3\2\2\u00f4") + buf.write(u"\u00f5\5\30\r\2\u00f5\u00f6\7\5\2\2\u00f6\u00f7\7\4\2") + buf.write(u"\2\u00f7\u011f\3\2\2\2\u00f8\u00f9\7\3\2\2\u00f9\u00fc") + 
buf.write(u"\5\30\r\2\u00fa\u00fb\7\5\2\2\u00fb\u00fd\5\30\r\2\u00fc") + buf.write(u"\u00fa\3\2\2\2\u00fd\u00fe\3\2\2\2\u00fe\u00fc\3\2\2") + buf.write(u"\2\u00fe\u00ff\3\2\2\2\u00ff\u0100\3\2\2\2\u0100\u0101") + buf.write(u"\7\4\2\2\u0101\u011f\3\2\2\2\u0102\u011f\5\36\20\2\u0103") + buf.write(u"\u0104\7\23\2\2\u0104\u0105\7\6\2\2\u0105\u0106\5\32") + buf.write(u"\16\2\u0106\u0107\7\5\2\2\u0107\u0108\5\30\r\2\u0108") + buf.write(u"\u0109\7\7\2\2\u0109\u011f\3\2\2\2\u010a\u010c\7\17\2") + buf.write(u"\2\u010b\u010d\5\26\f\2\u010c\u010b\3\2\2\2\u010c\u010d") + buf.write(u"\3\2\2\2\u010d\u010e\3\2\2\2\u010e\u0117\7\3\2\2\u010f") + buf.write(u"\u0114\5\30\r\2\u0110\u0111\7\5\2\2\u0111\u0113\5\30") + buf.write(u"\r\2\u0112\u0110\3\2\2\2\u0113\u0116\3\2\2\2\u0114\u0112") + buf.write(u"\3\2\2\2\u0114\u0115\3\2\2\2\u0115\u0118\3\2\2\2\u0116") + buf.write(u"\u0114\3\2\2\2\u0117\u010f\3\2\2\2\u0117\u0118\3\2\2") + buf.write(u"\2\u0118\u0119\3\2\2\2\u0119\u011a\7\4\2\2\u011a\u011b") + buf.write(u"\7\20\2\2\u011b\u011f\5\30\r\2\u011c\u011f\7\24\2\2\u011d") + buf.write(u"\u011f\7)\2\2\u011e\u00f1\3\2\2\2\u011e\u00f3\3\2\2\2") + buf.write(u"\u011e\u00f8\3\2\2\2\u011e\u0102\3\2\2\2\u011e\u0103") + buf.write(u"\3\2\2\2\u011e\u010a\3\2\2\2\u011e\u011c\3\2\2\2\u011e") + buf.write(u"\u011d\3\2\2\2\u011f\31\3\2\2\2\u0120\u0121\7\3\2\2\u0121") + buf.write(u"\u0132\7\4\2\2\u0122\u0123\7\3\2\2\u0123\u0124\5\34\17") + buf.write(u"\2\u0124\u0125\7\5\2\2\u0125\u0126\7\4\2\2\u0126\u0132") + buf.write(u"\3\2\2\2\u0127\u0128\7\3\2\2\u0128\u012b\5\34\17\2\u0129") + buf.write(u"\u012a\7\5\2\2\u012a\u012c\5\34\17\2\u012b\u0129\3\2") + buf.write(u"\2\2\u012c\u012d\3\2\2\2\u012d\u012b\3\2\2\2\u012d\u012e") + buf.write(u"\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0130\7\4\2\2\u0130") + buf.write(u"\u0132\3\2\2\2\u0131\u0120\3\2\2\2\u0131\u0122\3\2\2") + buf.write(u"\2\u0131\u0127\3\2\2\2\u0132\33\3\2\2\2\u0133\u0134\7") + buf.write(u"\3\2\2\u0134\u0135\5\34\17\2\u0135\u0136\7\4\2\2\u0136") + 
buf.write(u"\u0139\3\2\2\2\u0137\u0139\7)\2\2\u0138\u0133\3\2\2\2") + buf.write(u"\u0138\u0137\3\2\2\2\u0139\35\3\2\2\2\u013a\u013b\7*") + buf.write(u"\2\2\u013b\37\3\2\2\2\u013c\u013d\7\r\2\2\u013d\u013e") + buf.write(u"\5\6\4\2\u013e\u013f\7\16\2\2\u013f!\3\2\2\2\u0140\u0144") + buf.write(u"\7(\2\2\u0141\u0144\7)\2\2\u0142\u0144\7\'\2\2\u0143") + buf.write(u"\u0140\3\2\2\2\u0143\u0141\3\2\2\2\u0143\u0142\3\2\2") + buf.write(u"\2\u0144#\3\2\2\2\u0145\u014a\5\2\2\2\u0146\u014a\7#") + buf.write(u"\2\2\u0147\u014a\7$\2\2\u0148\u014a\7%\2\2\u0149\u0145") + buf.write(u"\3\2\2\2\u0149\u0146\3\2\2\2\u0149\u0147\3\2\2\2\u0149") + buf.write(u"\u0148\3\2\2\2\u014a%\3\2\2\2$,\60ISVcm\177\u0097\u009a") + buf.write(u"\u009d\u009f\u00a4\u00ab\u00b2\u00b9\u00c3\u00ca\u00cd") + buf.write(u"\u00d2\u00d9\u00dc\u00ea\u00ef\u00fe\u010c\u0114\u0117") + buf.write(u"\u011e\u012d\u0131\u0138\u0143\u0149") + return buf.getvalue() + + +class RelayParser ( Parser ): + + grammarFileName = "Relay.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ u"", u"'('", u"')'", u"','", u"'['", u"']'", + u"'if'", u"'else'", u"'let'", u"'='", u"';'", u"'{'", + u"'}'", u"'fn'", u"'->'", u"'def'", u"':'", u"'Tensor'", + u"'_'", u"'v0.0.2'", u"", u"", u"", + u"'*'", u"'/'", u"'+'", u"'-'", u"'<'", u"'>'", u"'<='", + u"'>='", u"'=='", u"'!='", u"", u"", + u"", u"'mut'" ] + + symbolicNames = [ u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"", + u"", u"", u"", u"SEMVER", + u"WS", u"LINE_COMMENT", u"COMMENT", u"MUL", u"DIV", + u"ADD", u"SUB", u"LT", u"GT", u"LE", u"GE", u"EQ", + u"NE", u"GLOBAL_VAR", u"LOCAL_VAR", u"GRAPH_VAR", + u"MUT", u"BOOL_LIT", u"FLOAT", u"NAT", u"CNAME" ] + + RULE_opIdent = 0 + RULE_prog = 1 + RULE_expr = 2 + RULE_func = 3 + RULE_defn = 4 + RULE_argList = 5 + RULE_varList = 6 + RULE_var = 7 + 
RULE_attrList = 8 + RULE_attr = 9 + RULE_typeParamSeq = 10 + RULE_type_ = 11 + RULE_shapeSeq = 12 + RULE_shape = 13 + RULE_typeIdent = 14 + RULE_body = 15 + RULE_scalar = 16 + RULE_ident = 17 + + ruleNames = [ u"opIdent", u"prog", u"expr", u"func", u"defn", u"argList", + u"varList", u"var", u"attrList", u"attr", u"typeParamSeq", + u"type_", u"shapeSeq", u"shape", u"typeIdent", u"body", + u"scalar", u"ident" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + SEMVER=19 + WS=20 + LINE_COMMENT=21 + COMMENT=22 + MUL=23 + DIV=24 + ADD=25 + SUB=26 + LT=27 + GT=28 + LE=29 + GE=30 + EQ=31 + NE=32 + GLOBAL_VAR=33 + LOCAL_VAR=34 + GRAPH_VAR=35 + MUT=36 + BOOL_LIT=37 + FLOAT=38 + NAT=39 + CNAME=40 + + def __init__(self, input, output=sys.stdout): + super(RelayParser, self).__init__(input, output=output) + self.checkVersion("4.7.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class OpIdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.OpIdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_opIdent + + def accept(self, visitor): + if hasattr(visitor, "visitOpIdent"): + return visitor.visitOpIdent(self) + else: + return visitor.visitChildren(self) + + + + + def opIdent(self): + + localctx = RelayParser.OpIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_opIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 36 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + 
self.exitRule() + return localctx + + + class ProgContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ProgContext, self).__init__(parent, invokingState) + self.parser = parser + + def SEMVER(self): + return self.getToken(RelayParser.SEMVER, 0) + + def EOF(self): + return self.getToken(RelayParser.EOF, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def defn(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.DefnContext) + else: + return self.getTypedRuleContext(RelayParser.DefnContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_prog + + def accept(self, visitor): + if hasattr(visitor, "visitProg"): + return visitor.visitProg(self) + else: + return visitor.visitChildren(self) + + + + + def prog(self): + + localctx = RelayParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 38 + self.match(RelayParser.SEMVER) + self.state = 46 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.EOF, RelayParser.T__14]: + self.state = 42 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__14: + self.state = 39 + self.defn() + self.state = 44 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [RelayParser.T__0, RelayParser.T__3, RelayParser.T__5, RelayParser.T__7, RelayParser.T__12, RelayParser.SUB, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.BOOL_LIT, RelayParser.FLOAT, RelayParser.NAT, RelayParser.CNAME]: + self.state = 45 + self.expr(0) + pass + else: + raise NoViableAltException(self) + + self.state = 48 + self.match(RelayParser.EOF) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + 
self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ExprContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_expr + + + def copyFrom(self, ctx): + super(RelayParser.ExprContext, self).copyFrom(ctx) + + + class IdentExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.IdentExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitIdentExpr"): + return visitor.visitIdentExpr(self) + else: + return visitor.visitChildren(self) + + + class CallContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.CallContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitCall"): + return visitor.visitCall(self) + else: + return visitor.visitChildren(self) + + + class NegContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.NegContext, self).__init__(parser) + self.copyFrom(ctx) + + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitNeg"): + return visitor.visitNeg(self) + else: + return visitor.visitChildren(self) + + + class TupleContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.TupleContext, self).__init__(parser) + self.copyFrom(ctx) + + 
def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTuple"): + return visitor.visitTuple(self) + else: + return visitor.visitChildren(self) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.ParensContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitParens"): + return visitor.visitParens(self) + else: + return visitor.visitChildren(self) + + + class FuncExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.FuncExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def func(self): + return self.getTypedRuleContext(RelayParser.FuncContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitFuncExpr"): + return visitor.visitFuncExpr(self) + else: + return visitor.visitChildren(self) + + + class ScalarExprContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.ScalarExprContext, self).__init__(parser) + self.copyFrom(ctx) + + def scalar(self): + return self.getTypedRuleContext(RelayParser.ScalarContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitScalarExpr"): + return visitor.visitScalarExpr(self) + else: + return visitor.visitChildren(self) + + + class LetContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.LetContext, self).__init__(parser) + self.copyFrom(ctx) + + def var(self): + return self.getTypedRuleContext(RelayParser.VarContext,0) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + 
else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUT(self): + return self.getToken(RelayParser.MUT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitLet"): + return visitor.visitLet(self) + else: + return visitor.visitChildren(self) + + + class TensorContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.TensorContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTensor"): + return visitor.visitTensor(self) + else: + return visitor.visitChildren(self) + + + class IfElseContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.IfElseContext, self).__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + def body(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.BodyContext) + else: + return self.getTypedRuleContext(RelayParser.BodyContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitIfElse"): + return visitor.visitIfElse(self) + else: + return visitor.visitChildren(self) + + + class GraphContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.GraphContext, self).__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitGraph"): + return visitor.visitGraph(self) + else: + return visitor.visitChildren(self) + + + class 
BinOpContext(ExprContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ExprContext) + super(RelayParser.BinOpContext, self).__init__(parser) + self.op = None # Token + self.copyFrom(ctx) + + def expr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUL(self): + return self.getToken(RelayParser.MUL, 0) + def DIV(self): + return self.getToken(RelayParser.DIV, 0) + def ADD(self): + return self.getToken(RelayParser.ADD, 0) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def LT(self): + return self.getToken(RelayParser.LT, 0) + def GT(self): + return self.getToken(RelayParser.GT, 0) + def LE(self): + return self.getToken(RelayParser.LE, 0) + def GE(self): + return self.getToken(RelayParser.GE, 0) + def EQ(self): + return self.getToken(RelayParser.EQ, 0) + def NE(self): + return self.getToken(RelayParser.NE, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitBinOp"): + return visitor.visitBinOp(self) + else: + return visitor.visitChildren(self) + + + + def expr(self, _p=0): + _parentctx = self._ctx + _parentState = self.state + localctx = RelayParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 4 + self.enterRecursionRule(localctx, 4, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 125 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + if la_ == 1: + localctx = RelayParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 51 + self.match(RelayParser.T__0) + self.state = 52 + self.expr(0) + self.state = 53 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.NegContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 55 + self.match(RelayParser.SUB) + self.state = 56 + self.expr(17) + 
pass + + elif la_ == 3: + localctx = RelayParser.FuncExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 57 + self.func() + pass + + elif la_ == 4: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 58 + self.match(RelayParser.T__0) + self.state = 59 + self.match(RelayParser.T__1) + pass + + elif la_ == 5: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 60 + self.match(RelayParser.T__0) + self.state = 61 + self.expr(0) + self.state = 62 + self.match(RelayParser.T__2) + self.state = 63 + self.match(RelayParser.T__1) + pass + + elif la_ == 6: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 65 + self.match(RelayParser.T__0) + self.state = 66 + self.expr(0) + self.state = 69 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 67 + self.match(RelayParser.T__2) + self.state = 68 + self.expr(0) + self.state = 71 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 73 + self.match(RelayParser.T__1) + pass + + elif la_ == 7: + localctx = RelayParser.TensorContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 75 + self.match(RelayParser.T__3) + self.state = 84 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 76 + self.expr(0) + self.state = 81 + self._errHandler.sync(self) + _la = self._input.LA(1) + 
while _la==RelayParser.T__2: + self.state = 77 + self.match(RelayParser.T__2) + self.state = 78 + self.expr(0) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 86 + self.match(RelayParser.T__4) + pass + + elif la_ == 8: + localctx = RelayParser.IfElseContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 87 + self.match(RelayParser.T__5) + self.state = 88 + self.match(RelayParser.T__0) + self.state = 89 + self.expr(0) + self.state = 90 + self.match(RelayParser.T__1) + self.state = 91 + self.body() + self.state = 92 + self.match(RelayParser.T__6) + self.state = 93 + self.body() + pass + + elif la_ == 9: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 95 + self.match(RelayParser.T__7) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 96 + self.match(RelayParser.MUT) + + + self.state = 99 + self.var() + self.state = 100 + self.match(RelayParser.T__8) + self.state = 101 + self.expr(0) + self.state = 102 + self.match(RelayParser.T__9) + self.state = 103 + self.expr(6) + pass + + elif la_ == 10: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 105 + self.match(RelayParser.T__7) + self.state = 107 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 106 + self.match(RelayParser.MUT) + + + self.state = 109 + self.var() + self.state = 110 + self.match(RelayParser.T__8) + self.state = 111 + self.match(RelayParser.T__10) + self.state = 112 + self.expr(0) + self.state = 113 + self.match(RelayParser.T__11) + self.state = 114 + self.match(RelayParser.T__9) + self.state = 115 + self.expr(5) + pass + + elif la_ == 11: + localctx = RelayParser.GraphContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 117 + self.ident() + self.state = 118 + 
self.match(RelayParser.T__8) + self.state = 119 + self.expr(0) + self.state = 120 + self.match(RelayParser.T__9) + self.state = 121 + self.expr(3) + pass + + elif la_ == 12: + localctx = RelayParser.IdentExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 123 + self.ident() + pass + + elif la_ == 13: + localctx = RelayParser.ScalarExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 124 + self.scalar() + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 157 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 155 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,10,self._ctx) + if la_ == 1: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 127 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 128 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.MUL or _la==RelayParser.DIV): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 129 + self.expr(17) + pass + + elif la_ == 2: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 130 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 131 + localctx.op = 
self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.ADD or _la==RelayParser.SUB): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 132 + self.expr(16) + pass + + elif la_ == 3: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 133 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 134 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.expr(15) + pass + + elif la_ == 4: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 136 + if not self.precpred(self._ctx, 13): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 13)") + self.state = 137 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.EQ or _la==RelayParser.NE): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 138 + self.expr(14) + pass + + elif la_ == 5: + localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 139 + if not self.precpred(self._ctx, 4): + from antlr4.error.Errors import FailedPredicateException + 
raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 140 + self.match(RelayParser.T__9) + self.state = 141 + self.expr(5) + pass + + elif la_ == 6: + localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 142 + if not self.precpred(self._ctx, 18): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") + self.state = 143 + self.match(RelayParser.T__0) + self.state = 152 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 144 + self.expr(0) + self.state = 149 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 145 + self.match(RelayParser.T__2) + self.state = 146 + self.expr(0) + self.state = 151 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 154 + self.match(RelayParser.T__1) + pass + + + self.state = 159 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class FuncContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.FuncContext, self).__init__(parent, invokingState) + self.parser = parser + + def argList(self): + return 
self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_func + + def accept(self, visitor): + if hasattr(visitor, "visitFunc"): + return visitor.visitFunc(self) + else: + return visitor.visitChildren(self) + + + + + def func(self): + + localctx = RelayParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_func) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 160 + self.match(RelayParser.T__12) + self.state = 162 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 161 + self.typeParamSeq() + + + self.state = 164 + self.match(RelayParser.T__0) + self.state = 165 + self.argList() + self.state = 166 + self.match(RelayParser.T__1) + self.state = 169 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 167 + self.match(RelayParser.T__13) + self.state = 168 + self.type_() + + + self.state = 171 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DefnContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.DefnContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return 
self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_defn + + def accept(self, visitor): + if hasattr(visitor, "visitDefn"): + return visitor.visitDefn(self) + else: + return visitor.visitChildren(self) + + + + + def defn(self): + + localctx = RelayParser.DefnContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_defn) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 173 + self.match(RelayParser.T__14) + self.state = 174 + self.ident() + self.state = 176 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 175 + self.typeParamSeq() + + + self.state = 178 + self.match(RelayParser.T__0) + self.state = 179 + self.argList() + self.state = 180 + self.match(RelayParser.T__1) + self.state = 183 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 181 + self.match(RelayParser.T__13) + self.state = 182 + self.type_() + + + self.state = 185 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ArgListContext, self).__init__(parent, invokingState) + self.parser = parser + + def varList(self): + return self.getTypedRuleContext(RelayParser.VarListContext,0) + + + def attrList(self): + return self.getTypedRuleContext(RelayParser.AttrListContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_argList + + def accept(self, visitor): + if hasattr(visitor, "visitArgList"): + return visitor.visitArgList(self) + else: + return visitor.visitChildren(self) + + + + + def argList(self): + + localctx = 
RelayParser.ArgListContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_argList) + try: + self.state = 193 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,16,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 187 + self.varList() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 188 + self.attrList() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 189 + self.varList() + self.state = 190 + self.match(RelayParser.T__2) + self.state = 191 + self.attrList() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.VarListContext, self).__init__(parent, invokingState) + self.parser = parser + + def var(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.VarContext) + else: + return self.getTypedRuleContext(RelayParser.VarContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_varList + + def accept(self, visitor): + if hasattr(visitor, "visitVarList"): + return visitor.visitVarList(self) + else: + return visitor.visitChildren(self) + + + + + def varList(self): + + localctx = RelayParser.VarListContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_varList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.CNAME))) != 0): + self.state = 195 + self.var() + self.state = 200 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + 
while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 196 + self.match(RelayParser.T__2) + self.state = 197 + self.var() + self.state = 202 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.VarContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_var + + def accept(self, visitor): + if hasattr(visitor, "visitVar"): + return visitor.visitVar(self) + else: + return visitor.visitChildren(self) + + + + + def var(self): + + localctx = RelayParser.VarContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_var) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 205 + self.ident() + self.state = 208 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__15: + self.state = 206 + self.match(RelayParser.T__15) + self.state = 207 + self.type_() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrListContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.AttrListContext, self).__init__(parent, invokingState) + self.parser = parser + + def attr(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.AttrContext) + else: + return 
self.getTypedRuleContext(RelayParser.AttrContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_attrList + + def accept(self, visitor): + if hasattr(visitor, "visitAttrList"): + return visitor.visitAttrList(self) + else: + return visitor.visitChildren(self) + + + + + def attrList(self): + + localctx = RelayParser.AttrListContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_attrList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.CNAME: + self.state = 210 + self.attr() + self.state = 215 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 211 + self.match(RelayParser.T__2) + self.state = 212 + self.attr() + self.state = 217 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.AttrContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_attr + + def accept(self, visitor): + if hasattr(visitor, "visitAttr"): + return visitor.visitAttr(self) + else: + return visitor.visitChildren(self) + + + + + def attr(self): + + localctx = RelayParser.AttrContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_attr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 220 + self.match(RelayParser.CNAME) + self.state = 221 + self.match(RelayParser.T__8) + self.state = 222 + self.expr(0) + except RecognitionException as re: + 
localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeParamSeqContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.TypeParamSeqContext, self).__init__(parent, invokingState) + self.parser = parser + + def ident(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.IdentContext) + else: + return self.getTypedRuleContext(RelayParser.IdentContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_typeParamSeq + + def accept(self, visitor): + if hasattr(visitor, "visitTypeParamSeq"): + return visitor.visitTypeParamSeq(self) + else: + return visitor.visitChildren(self) + + + + + def typeParamSeq(self): + + localctx = RelayParser.TypeParamSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_typeParamSeq) + self._la = 0 # Token type + try: + self.state = 237 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,23,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 224 + self.match(RelayParser.T__3) + self.state = 225 + self.match(RelayParser.T__4) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 226 + self.match(RelayParser.T__3) + self.state = 227 + self.ident() + self.state = 232 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 228 + self.match(RelayParser.T__2) + self.state = 229 + self.ident() + self.state = 234 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 235 + self.match(RelayParser.T__4) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Type_Context(ParserRuleContext): + + def __init__(self, parser, parent=None, 
invokingState=-1): + super(RelayParser.Type_Context, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_type_ + + + def copyFrom(self, ctx): + super(RelayParser.Type_Context, self).copyFrom(ctx) + + + + class IntTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.IntTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitIntType"): + return visitor.visitIntType(self) + else: + return visitor.visitChildren(self) + + + class TupleTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TupleTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def type_(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + + def accept(self, visitor): + if hasattr(visitor, "visitTupleType"): + return visitor.visitTupleType(self) + else: + return visitor.visitChildren(self) + + + class TypeIdentTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TypeIdentTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def typeIdent(self): + return self.getTypedRuleContext(RelayParser.TypeIdentContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitTypeIdentType"): + return visitor.visitTypeIdentType(self) + else: + return visitor.visitChildren(self) + + + class IncompleteTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.IncompleteTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + + def accept(self, visitor): + if hasattr(visitor, "visitIncompleteType"): + return 
visitor.visitIncompleteType(self) + else: + return visitor.visitChildren(self) + + + class TensorTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.TensorTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def shapeSeq(self): + return self.getTypedRuleContext(RelayParser.ShapeSeqContext,0) + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitTensorType"): + return visitor.visitTensorType(self) + else: + return visitor.visitChildren(self) + + + class FuncTypeContext(Type_Context): + + def __init__(self, parser, ctx): # actually a RelayParser.Type_Context) + super(RelayParser.FuncTypeContext, self).__init__(parser) + self.copyFrom(ctx) + + def type_(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitFuncType"): + return visitor.visitFuncType(self) + else: + return visitor.visitChildren(self) + + + + def type_(self): + + localctx = RelayParser.Type_Context(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_type_) + self._la = 0 # Token type + try: + self.state = 284 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 239 + self.match(RelayParser.T__0) + self.state = 240 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 241 + self.match(RelayParser.T__0) + self.state = 242 + self.type_() + self.state = 243 + self.match(RelayParser.T__2) + 
self.state = 244 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.match(RelayParser.T__0) + self.state = 247 + self.type_() + self.state = 250 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 248 + self.match(RelayParser.T__2) + self.state = 249 + self.type_() + self.state = 252 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 254 + self.match(RelayParser.T__1) + pass + + elif la_ == 4: + localctx = RelayParser.TypeIdentTypeContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 256 + self.typeIdent() + pass + + elif la_ == 5: + localctx = RelayParser.TensorTypeContext(self, localctx) + self.enterOuterAlt(localctx, 5) + self.state = 257 + self.match(RelayParser.T__16) + self.state = 258 + self.match(RelayParser.T__3) + self.state = 259 + self.shapeSeq() + self.state = 260 + self.match(RelayParser.T__2) + self.state = 261 + self.type_() + self.state = 262 + self.match(RelayParser.T__4) + pass + + elif la_ == 6: + localctx = RelayParser.FuncTypeContext(self, localctx) + self.enterOuterAlt(localctx, 6) + self.state = 264 + self.match(RelayParser.T__12) + self.state = 266 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 265 + self.typeParamSeq() + + + self.state = 268 + self.match(RelayParser.T__0) + self.state = 277 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__12) | (1 << RelayParser.T__16) | (1 << RelayParser.T__17) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 269 + self.type_() + self.state = 274 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 270 + self.match(RelayParser.T__2) + self.state 
= 271 + self.type_() + self.state = 276 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 279 + self.match(RelayParser.T__1) + self.state = 280 + self.match(RelayParser.T__13) + self.state = 281 + self.type_() + pass + + elif la_ == 7: + localctx = RelayParser.IncompleteTypeContext(self, localctx) + self.enterOuterAlt(localctx, 7) + self.state = 282 + self.match(RelayParser.T__17) + pass + + elif la_ == 8: + localctx = RelayParser.IntTypeContext(self, localctx) + self.enterOuterAlt(localctx, 8) + self.state = 283 + self.match(RelayParser.NAT) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeSeqContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ShapeSeqContext, self).__init__(parent, invokingState) + self.parser = parser + + def shape(self, i=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ShapeContext) + else: + return self.getTypedRuleContext(RelayParser.ShapeContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_shapeSeq + + def accept(self, visitor): + if hasattr(visitor, "visitShapeSeq"): + return visitor.visitShapeSeq(self) + else: + return visitor.visitChildren(self) + + + + + def shapeSeq(self): + + localctx = RelayParser.ShapeSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_shapeSeq) + self._la = 0 # Token type + try: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 286 + self.match(RelayParser.T__0) + self.state = 287 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 288 + self.match(RelayParser.T__0) + self.state = 289 + self.shape() + self.state = 290 + 
self.match(RelayParser.T__2) + self.state = 291 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 293 + self.match(RelayParser.T__0) + self.state = 294 + self.shape() + self.state = 297 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 295 + self.match(RelayParser.T__2) + self.state = 296 + self.shape() + self.state = 299 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 301 + self.match(RelayParser.T__1) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ShapeContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_shape + + + def copyFrom(self, ctx): + super(RelayParser.ShapeContext, self).copyFrom(ctx) + + + + class ParensShapeContext(ShapeContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ShapeContext) + super(RelayParser.ParensShapeContext, self).__init__(parser) + self.copyFrom(ctx) + + def shape(self): + return self.getTypedRuleContext(RelayParser.ShapeContext,0) + + + def accept(self, visitor): + if hasattr(visitor, "visitParensShape"): + return visitor.visitParensShape(self) + else: + return visitor.visitChildren(self) + + + class IntShapeContext(ShapeContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ShapeContext) + super(RelayParser.IntShapeContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitIntShape"): + return visitor.visitIntShape(self) + else: + return visitor.visitChildren(self) + + + + def 
shape(self): + + localctx = RelayParser.ShapeContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_shape) + try: + self.state = 310 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.T__0]: + localctx = RelayParser.ParensShapeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 305 + self.match(RelayParser.T__0) + self.state = 306 + self.shape() + self.state = 307 + self.match(RelayParser.T__1) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.IntShapeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 309 + self.match(RelayParser.NAT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeIdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.TypeIdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def getRuleIndex(self): + return RelayParser.RULE_typeIdent + + def accept(self, visitor): + if hasattr(visitor, "visitTypeIdent"): + return visitor.visitTypeIdent(self) + else: + return visitor.visitChildren(self) + + + + + def typeIdent(self): + + localctx = RelayParser.TypeIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_typeIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 312 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BodyContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.BodyContext, self).__init__(parent, 
invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_body + + def accept(self, visitor): + if hasattr(visitor, "visitBody"): + return visitor.visitBody(self) + else: + return visitor.visitChildren(self) + + + + + def body(self): + + localctx = RelayParser.BodyContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_body) + try: + self.enterOuterAlt(localctx, 1) + self.state = 314 + self.match(RelayParser.T__10) + self.state = 315 + self.expr(0) + self.state = 316 + self.match(RelayParser.T__11) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ScalarContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.ScalarContext, self).__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_scalar + + + def copyFrom(self, ctx): + super(RelayParser.ScalarContext, self).copyFrom(ctx) + + + + class ScalarFloatContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarFloatContext, self).__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(RelayParser.FLOAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarFloat"): + return visitor.visitScalarFloat(self) + else: + return visitor.visitChildren(self) + + + class ScalarBoolContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarBoolContext, self).__init__(parser) + self.copyFrom(ctx) + + def BOOL_LIT(self): + return self.getToken(RelayParser.BOOL_LIT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarBool"): + return 
visitor.visitScalarBool(self) + else: + return visitor.visitChildren(self) + + + class ScalarIntContext(ScalarContext): + + def __init__(self, parser, ctx): # actually a RelayParser.ScalarContext) + super(RelayParser.ScalarIntContext, self).__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor): + if hasattr(visitor, "visitScalarInt"): + return visitor.visitScalarInt(self) + else: + return visitor.visitChildren(self) + + + + def scalar(self): + + localctx = RelayParser.ScalarContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_scalar) + try: + self.state = 321 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.FLOAT]: + localctx = RelayParser.ScalarFloatContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.match(RelayParser.FLOAT) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.ScalarIntContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 319 + self.match(RelayParser.NAT) + pass + elif token in [RelayParser.BOOL_LIT]: + localctx = RelayParser.ScalarBoolContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 320 + self.match(RelayParser.BOOL_LIT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IdentContext(ParserRuleContext): + + def __init__(self, parser, parent=None, invokingState=-1): + super(RelayParser.IdentContext, self).__init__(parent, invokingState) + self.parser = parser + + def opIdent(self): + return self.getTypedRuleContext(RelayParser.OpIdentContext,0) + + + def GLOBAL_VAR(self): + return self.getToken(RelayParser.GLOBAL_VAR, 0) + + def LOCAL_VAR(self): + return self.getToken(RelayParser.LOCAL_VAR, 0) + + def GRAPH_VAR(self): + return 
self.getToken(RelayParser.GRAPH_VAR, 0) + + def getRuleIndex(self): + return RelayParser.RULE_ident + + def accept(self, visitor): + if hasattr(visitor, "visitIdent"): + return visitor.visitIdent(self) + else: + return visitor.visitChildren(self) + + + + + def ident(self): + + localctx = RelayParser.IdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_ident) + try: + self.state = 327 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.CNAME]: + self.enterOuterAlt(localctx, 1) + self.state = 323 + self.opIdent() + pass + elif token in [RelayParser.GLOBAL_VAR]: + self.enterOuterAlt(localctx, 2) + self.state = 324 + self.match(RelayParser.GLOBAL_VAR) + pass + elif token in [RelayParser.LOCAL_VAR]: + self.enterOuterAlt(localctx, 3) + self.state = 325 + self.match(RelayParser.LOCAL_VAR) + pass + elif token in [RelayParser.GRAPH_VAR]: + self.enterOuterAlt(localctx, 4) + self.state = 326 + self.match(RelayParser.GRAPH_VAR) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx, ruleIndex, predIndex): + if self._predicates == None: + self._predicates = dict() + self._predicates[2] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx, predIndex): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 13) + + + if predIndex == 4: + return self.precpred(self._ctx, 4) + + + if predIndex == 5: + return self.precpred(self._ctx, 18) + + + + + 
diff --git a/python/tvm/relay/grammar/py2/RelayVisitor.py b/python/tvm/relay/grammar/py2/RelayVisitor.py new file mode 100644 index 000000000000..eae67d8cff58 --- /dev/null +++ b/python/tvm/relay/grammar/py2/RelayVisitor.py @@ -0,0 +1,192 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * + +# This class defines a complete generic visitor for a parse tree produced by RelayParser. + +class RelayVisitor(ParseTreeVisitor): + + # Visit a parse tree produced by RelayParser#opIdent. + def visitOpIdent(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#prog. + def visitProg(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#identExpr. + def visitIdentExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#call. + def visitCall(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#neg. + def visitNeg(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tuple. + def visitTuple(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parens. + def visitParens(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcExpr. + def visitFuncExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarExpr. + def visitScalarExpr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#let. + def visitLet(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensor. + def visitTensor(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ifElse. + def visitIfElse(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#graph. 
+ def visitGraph(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#binOp. + def visitBinOp(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#func. + def visitFunc(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#defn. + def visitDefn(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#argList. + def visitArgList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#varList. + def visitVarList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#var. + def visitVar(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attrList. + def visitAttrList(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attr. + def visitAttr(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeParamSeq. + def visitTypeParamSeq(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tupleType. + def visitTupleType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdentType. + def visitTypeIdentType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensorType. + def visitTensorType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcType. + def visitFuncType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#incompleteType. + def visitIncompleteType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intType. 
+ def visitIntType(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#shapeSeq. + def visitShapeSeq(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parensShape. + def visitParensShape(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intShape. + def visitIntShape(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdent. + def visitTypeIdent(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#body. + def visitBody(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarFloat. + def visitScalarFloat(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarInt. + def visitScalarInt(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarBool. + def visitScalarBool(self, ctx): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ident. 
+ def visitIdent(self, ctx): + return self.visitChildren(ctx) + + diff --git a/python/tvm/relay/grammar/py3/.gitattributes b/python/tvm/relay/grammar/py3/.gitattributes new file mode 100644 index 000000000000..4adf65fa2f3c --- /dev/null +++ b/python/tvm/relay/grammar/py3/.gitattributes @@ -0,0 +1,3 @@ +Relay* binary +Relay* linguist-generated=true +Relay* linguist-detectable=false \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/.gitignore b/python/tvm/relay/grammar/py3/.gitignore deleted file mode 100644 index d677ff551940..000000000000 --- a/python/tvm/relay/grammar/py3/.gitignore +++ /dev/null @@ -1 +0,0 @@ -Relay* diff --git a/python/tvm/relay/grammar/py3/Relay.interp b/python/tvm/relay/grammar/py3/Relay.interp new file mode 100644 index 000000000000..c6893d096168 --- /dev/null +++ b/python/tvm/relay/grammar/py3/Relay.interp @@ -0,0 +1,109 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +opIdent +prog +expr +func +defn +argList +varList +var +attrList +attr +typeParamSeq +type_ +shapeSeq +shape +typeIdent +body +scalar +ident + + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 42, 332, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 3, 2, 3, 2, 3, 3, 3, 3, 7, 3, 43, 10, 3, 12, 3, 14, 3, 46, 11, 3, 3, 3, 5, 3, 49, 10, 3, 3, 
3, 3, 3, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 6, 4, 72, 10, 4, 13, 4, 14, 4, 73, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 82, 10, 4, 12, 4, 14, 4, 85, 11, 4, 5, 4, 87, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 100, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 110, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 5, 4, 128, 10, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 7, 4, 150, 10, 4, 12, 4, 14, 4, 153, 11, 4, 5, 4, 155, 10, 4, 3, 4, 7, 4, 158, 10, 4, 12, 4, 14, 4, 161, 11, 4, 3, 5, 3, 5, 5, 5, 165, 10, 5, 3, 5, 3, 5, 3, 5, 3, 5, 3, 5, 5, 5, 172, 10, 5, 3, 5, 3, 5, 3, 6, 3, 6, 3, 6, 5, 6, 179, 10, 6, 3, 6, 3, 6, 3, 6, 3, 6, 3, 6, 5, 6, 186, 10, 6, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 3, 7, 5, 7, 196, 10, 7, 3, 8, 3, 8, 3, 8, 7, 8, 201, 10, 8, 12, 8, 14, 8, 204, 11, 8, 5, 8, 206, 10, 8, 3, 9, 3, 9, 3, 9, 5, 9, 211, 10, 9, 3, 10, 3, 10, 3, 10, 7, 10, 216, 10, 10, 12, 10, 14, 10, 219, 11, 10, 5, 10, 221, 10, 10, 3, 11, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 3, 12, 7, 12, 233, 10, 12, 12, 12, 14, 12, 236, 11, 12, 3, 12, 3, 12, 5, 12, 240, 10, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 6, 13, 253, 10, 13, 13, 13, 14, 13, 254, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 269, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 7, 13, 275, 10, 13, 12, 13, 14, 13, 278, 11, 13, 5, 13, 280, 10, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 287, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 6, 14, 300, 10, 14, 13, 14, 14, 14, 301, 3, 14, 3, 14, 5, 14, 306, 10, 14, 3, 15, 3, 15, 3, 15, 3, 15, 3, 15, 5, 15, 313, 10, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 5, 18, 
324, 10, 18, 3, 19, 3, 19, 3, 19, 3, 19, 5, 19, 330, 10, 19, 3, 19, 2, 3, 6, 20, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 2, 6, 3, 2, 25, 26, 3, 2, 27, 28, 3, 2, 29, 32, 3, 2, 33, 34, 2, 373, 2, 38, 3, 2, 2, 2, 4, 40, 3, 2, 2, 2, 6, 127, 3, 2, 2, 2, 8, 162, 3, 2, 2, 2, 10, 175, 3, 2, 2, 2, 12, 195, 3, 2, 2, 2, 14, 205, 3, 2, 2, 2, 16, 207, 3, 2, 2, 2, 18, 220, 3, 2, 2, 2, 20, 222, 3, 2, 2, 2, 22, 239, 3, 2, 2, 2, 24, 286, 3, 2, 2, 2, 26, 305, 3, 2, 2, 2, 28, 312, 3, 2, 2, 2, 30, 314, 3, 2, 2, 2, 32, 316, 3, 2, 2, 2, 34, 323, 3, 2, 2, 2, 36, 329, 3, 2, 2, 2, 38, 39, 7, 42, 2, 2, 39, 3, 3, 2, 2, 2, 40, 48, 7, 21, 2, 2, 41, 43, 5, 10, 6, 2, 42, 41, 3, 2, 2, 2, 43, 46, 3, 2, 2, 2, 44, 42, 3, 2, 2, 2, 44, 45, 3, 2, 2, 2, 45, 49, 3, 2, 2, 2, 46, 44, 3, 2, 2, 2, 47, 49, 5, 6, 4, 2, 48, 44, 3, 2, 2, 2, 48, 47, 3, 2, 2, 2, 49, 50, 3, 2, 2, 2, 50, 51, 7, 2, 2, 3, 51, 5, 3, 2, 2, 2, 52, 53, 8, 4, 1, 2, 53, 54, 7, 3, 2, 2, 54, 55, 5, 6, 4, 2, 55, 56, 7, 4, 2, 2, 56, 128, 3, 2, 2, 2, 57, 58, 7, 28, 2, 2, 58, 128, 5, 6, 4, 19, 59, 128, 5, 8, 5, 2, 60, 61, 7, 3, 2, 2, 61, 128, 7, 4, 2, 2, 62, 63, 7, 3, 2, 2, 63, 64, 5, 6, 4, 2, 64, 65, 7, 5, 2, 2, 65, 66, 7, 4, 2, 2, 66, 128, 3, 2, 2, 2, 67, 68, 7, 3, 2, 2, 68, 71, 5, 6, 4, 2, 69, 70, 7, 5, 2, 2, 70, 72, 5, 6, 4, 2, 71, 69, 3, 2, 2, 2, 72, 73, 3, 2, 2, 2, 73, 71, 3, 2, 2, 2, 73, 74, 3, 2, 2, 2, 74, 75, 3, 2, 2, 2, 75, 76, 7, 4, 2, 2, 76, 128, 3, 2, 2, 2, 77, 86, 7, 6, 2, 2, 78, 83, 5, 6, 4, 2, 79, 80, 7, 5, 2, 2, 80, 82, 5, 6, 4, 2, 81, 79, 3, 2, 2, 2, 82, 85, 3, 2, 2, 2, 83, 81, 3, 2, 2, 2, 83, 84, 3, 2, 2, 2, 84, 87, 3, 2, 2, 2, 85, 83, 3, 2, 2, 2, 86, 78, 3, 2, 2, 2, 86, 87, 3, 2, 2, 2, 87, 88, 3, 2, 2, 2, 88, 128, 7, 7, 2, 2, 89, 90, 7, 8, 2, 2, 90, 91, 7, 3, 2, 2, 91, 92, 5, 6, 4, 2, 92, 93, 7, 4, 2, 2, 93, 94, 5, 32, 17, 2, 94, 95, 7, 9, 2, 2, 95, 96, 5, 32, 17, 2, 96, 128, 3, 2, 2, 2, 97, 99, 7, 10, 2, 2, 98, 100, 7, 38, 2, 2, 99, 98, 3, 2, 2, 2, 99, 100, 3, 2, 2, 2, 100, 101, 3, 2, 2, 2, 101, 
102, 5, 16, 9, 2, 102, 103, 7, 11, 2, 2, 103, 104, 5, 6, 4, 2, 104, 105, 7, 12, 2, 2, 105, 106, 5, 6, 4, 8, 106, 128, 3, 2, 2, 2, 107, 109, 7, 10, 2, 2, 108, 110, 7, 38, 2, 2, 109, 108, 3, 2, 2, 2, 109, 110, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 112, 5, 16, 9, 2, 112, 113, 7, 11, 2, 2, 113, 114, 7, 13, 2, 2, 114, 115, 5, 6, 4, 2, 115, 116, 7, 14, 2, 2, 116, 117, 7, 12, 2, 2, 117, 118, 5, 6, 4, 7, 118, 128, 3, 2, 2, 2, 119, 120, 5, 36, 19, 2, 120, 121, 7, 11, 2, 2, 121, 122, 5, 6, 4, 2, 122, 123, 7, 12, 2, 2, 123, 124, 5, 6, 4, 5, 124, 128, 3, 2, 2, 2, 125, 128, 5, 36, 19, 2, 126, 128, 5, 34, 18, 2, 127, 52, 3, 2, 2, 2, 127, 57, 3, 2, 2, 2, 127, 59, 3, 2, 2, 2, 127, 60, 3, 2, 2, 2, 127, 62, 3, 2, 2, 2, 127, 67, 3, 2, 2, 2, 127, 77, 3, 2, 2, 2, 127, 89, 3, 2, 2, 2, 127, 97, 3, 2, 2, 2, 127, 107, 3, 2, 2, 2, 127, 119, 3, 2, 2, 2, 127, 125, 3, 2, 2, 2, 127, 126, 3, 2, 2, 2, 128, 159, 3, 2, 2, 2, 129, 130, 12, 18, 2, 2, 130, 131, 9, 2, 2, 2, 131, 158, 5, 6, 4, 19, 132, 133, 12, 17, 2, 2, 133, 134, 9, 3, 2, 2, 134, 158, 5, 6, 4, 18, 135, 136, 12, 16, 2, 2, 136, 137, 9, 4, 2, 2, 137, 158, 5, 6, 4, 17, 138, 139, 12, 15, 2, 2, 139, 140, 9, 5, 2, 2, 140, 158, 5, 6, 4, 16, 141, 142, 12, 6, 2, 2, 142, 143, 7, 12, 2, 2, 143, 158, 5, 6, 4, 7, 144, 145, 12, 20, 2, 2, 145, 154, 7, 3, 2, 2, 146, 151, 5, 6, 4, 2, 147, 148, 7, 5, 2, 2, 148, 150, 5, 6, 4, 2, 149, 147, 3, 2, 2, 2, 150, 153, 3, 2, 2, 2, 151, 149, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 155, 3, 2, 2, 2, 153, 151, 3, 2, 2, 2, 154, 146, 3, 2, 2, 2, 154, 155, 3, 2, 2, 2, 155, 156, 3, 2, 2, 2, 156, 158, 7, 4, 2, 2, 157, 129, 3, 2, 2, 2, 157, 132, 3, 2, 2, 2, 157, 135, 3, 2, 2, 2, 157, 138, 3, 2, 2, 2, 157, 141, 3, 2, 2, 2, 157, 144, 3, 2, 2, 2, 158, 161, 3, 2, 2, 2, 159, 157, 3, 2, 2, 2, 159, 160, 3, 2, 2, 2, 160, 7, 3, 2, 2, 2, 161, 159, 3, 2, 2, 2, 162, 164, 7, 15, 2, 2, 163, 165, 5, 22, 12, 2, 164, 163, 3, 2, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 3, 2, 2, 2, 166, 167, 7, 3, 2, 2, 167, 168, 5, 12, 7, 2, 168, 171, 7, 4, 2, 
2, 169, 170, 7, 16, 2, 2, 170, 172, 5, 24, 13, 2, 171, 169, 3, 2, 2, 2, 171, 172, 3, 2, 2, 2, 172, 173, 3, 2, 2, 2, 173, 174, 5, 32, 17, 2, 174, 9, 3, 2, 2, 2, 175, 176, 7, 17, 2, 2, 176, 178, 5, 36, 19, 2, 177, 179, 5, 22, 12, 2, 178, 177, 3, 2, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 3, 2, 2, 2, 180, 181, 7, 3, 2, 2, 181, 182, 5, 12, 7, 2, 182, 185, 7, 4, 2, 2, 183, 184, 7, 16, 2, 2, 184, 186, 5, 24, 13, 2, 185, 183, 3, 2, 2, 2, 185, 186, 3, 2, 2, 2, 186, 187, 3, 2, 2, 2, 187, 188, 5, 32, 17, 2, 188, 11, 3, 2, 2, 2, 189, 196, 5, 14, 8, 2, 190, 196, 5, 18, 10, 2, 191, 192, 5, 14, 8, 2, 192, 193, 7, 5, 2, 2, 193, 194, 5, 18, 10, 2, 194, 196, 3, 2, 2, 2, 195, 189, 3, 2, 2, 2, 195, 190, 3, 2, 2, 2, 195, 191, 3, 2, 2, 2, 196, 13, 3, 2, 2, 2, 197, 202, 5, 16, 9, 2, 198, 199, 7, 5, 2, 2, 199, 201, 5, 16, 9, 2, 200, 198, 3, 2, 2, 2, 201, 204, 3, 2, 2, 2, 202, 200, 3, 2, 2, 2, 202, 203, 3, 2, 2, 2, 203, 206, 3, 2, 2, 2, 204, 202, 3, 2, 2, 2, 205, 197, 3, 2, 2, 2, 205, 206, 3, 2, 2, 2, 206, 15, 3, 2, 2, 2, 207, 210, 5, 36, 19, 2, 208, 209, 7, 18, 2, 2, 209, 211, 5, 24, 13, 2, 210, 208, 3, 2, 2, 2, 210, 211, 3, 2, 2, 2, 211, 17, 3, 2, 2, 2, 212, 217, 5, 20, 11, 2, 213, 214, 7, 5, 2, 2, 214, 216, 5, 20, 11, 2, 215, 213, 3, 2, 2, 2, 216, 219, 3, 2, 2, 2, 217, 215, 3, 2, 2, 2, 217, 218, 3, 2, 2, 2, 218, 221, 3, 2, 2, 2, 219, 217, 3, 2, 2, 2, 220, 212, 3, 2, 2, 2, 220, 221, 3, 2, 2, 2, 221, 19, 3, 2, 2, 2, 222, 223, 7, 42, 2, 2, 223, 224, 7, 11, 2, 2, 224, 225, 5, 6, 4, 2, 225, 21, 3, 2, 2, 2, 226, 227, 7, 6, 2, 2, 227, 240, 7, 7, 2, 2, 228, 229, 7, 6, 2, 2, 229, 234, 5, 36, 19, 2, 230, 231, 7, 5, 2, 2, 231, 233, 5, 36, 19, 2, 232, 230, 3, 2, 2, 2, 233, 236, 3, 2, 2, 2, 234, 232, 3, 2, 2, 2, 234, 235, 3, 2, 2, 2, 235, 237, 3, 2, 2, 2, 236, 234, 3, 2, 2, 2, 237, 238, 7, 7, 2, 2, 238, 240, 3, 2, 2, 2, 239, 226, 3, 2, 2, 2, 239, 228, 3, 2, 2, 2, 240, 23, 3, 2, 2, 2, 241, 242, 7, 3, 2, 2, 242, 287, 7, 4, 2, 2, 243, 244, 7, 3, 2, 2, 244, 245, 5, 24, 13, 2, 245, 246, 7, 5, 2, 2, 246, 
247, 7, 4, 2, 2, 247, 287, 3, 2, 2, 2, 248, 249, 7, 3, 2, 2, 249, 252, 5, 24, 13, 2, 250, 251, 7, 5, 2, 2, 251, 253, 5, 24, 13, 2, 252, 250, 3, 2, 2, 2, 253, 254, 3, 2, 2, 2, 254, 252, 3, 2, 2, 2, 254, 255, 3, 2, 2, 2, 255, 256, 3, 2, 2, 2, 256, 257, 7, 4, 2, 2, 257, 287, 3, 2, 2, 2, 258, 287, 5, 30, 16, 2, 259, 260, 7, 19, 2, 2, 260, 261, 7, 6, 2, 2, 261, 262, 5, 26, 14, 2, 262, 263, 7, 5, 2, 2, 263, 264, 5, 24, 13, 2, 264, 265, 7, 7, 2, 2, 265, 287, 3, 2, 2, 2, 266, 268, 7, 15, 2, 2, 267, 269, 5, 22, 12, 2, 268, 267, 3, 2, 2, 2, 268, 269, 3, 2, 2, 2, 269, 270, 3, 2, 2, 2, 270, 279, 7, 3, 2, 2, 271, 276, 5, 24, 13, 2, 272, 273, 7, 5, 2, 2, 273, 275, 5, 24, 13, 2, 274, 272, 3, 2, 2, 2, 275, 278, 3, 2, 2, 2, 276, 274, 3, 2, 2, 2, 276, 277, 3, 2, 2, 2, 277, 280, 3, 2, 2, 2, 278, 276, 3, 2, 2, 2, 279, 271, 3, 2, 2, 2, 279, 280, 3, 2, 2, 2, 280, 281, 3, 2, 2, 2, 281, 282, 7, 4, 2, 2, 282, 283, 7, 16, 2, 2, 283, 287, 5, 24, 13, 2, 284, 287, 7, 20, 2, 2, 285, 287, 7, 41, 2, 2, 286, 241, 3, 2, 2, 2, 286, 243, 3, 2, 2, 2, 286, 248, 3, 2, 2, 2, 286, 258, 3, 2, 2, 2, 286, 259, 3, 2, 2, 2, 286, 266, 3, 2, 2, 2, 286, 284, 3, 2, 2, 2, 286, 285, 3, 2, 2, 2, 287, 25, 3, 2, 2, 2, 288, 289, 7, 3, 2, 2, 289, 306, 7, 4, 2, 2, 290, 291, 7, 3, 2, 2, 291, 292, 5, 28, 15, 2, 292, 293, 7, 5, 2, 2, 293, 294, 7, 4, 2, 2, 294, 306, 3, 2, 2, 2, 295, 296, 7, 3, 2, 2, 296, 299, 5, 28, 15, 2, 297, 298, 7, 5, 2, 2, 298, 300, 5, 28, 15, 2, 299, 297, 3, 2, 2, 2, 300, 301, 3, 2, 2, 2, 301, 299, 3, 2, 2, 2, 301, 302, 3, 2, 2, 2, 302, 303, 3, 2, 2, 2, 303, 304, 7, 4, 2, 2, 304, 306, 3, 2, 2, 2, 305, 288, 3, 2, 2, 2, 305, 290, 3, 2, 2, 2, 305, 295, 3, 2, 2, 2, 306, 27, 3, 2, 2, 2, 307, 308, 7, 3, 2, 2, 308, 309, 5, 28, 15, 2, 309, 310, 7, 4, 2, 2, 310, 313, 3, 2, 2, 2, 311, 313, 7, 41, 2, 2, 312, 307, 3, 2, 2, 2, 312, 311, 3, 2, 2, 2, 313, 29, 3, 2, 2, 2, 314, 315, 7, 42, 2, 2, 315, 31, 3, 2, 2, 2, 316, 317, 7, 13, 2, 2, 317, 318, 5, 6, 4, 2, 318, 319, 7, 14, 2, 2, 319, 33, 3, 2, 2, 2, 320, 324, 7, 40, 
2, 2, 321, 324, 7, 41, 2, 2, 322, 324, 7, 39, 2, 2, 323, 320, 3, 2, 2, 2, 323, 321, 3, 2, 2, 2, 323, 322, 3, 2, 2, 2, 324, 35, 3, 2, 2, 2, 325, 330, 5, 2, 2, 2, 326, 330, 7, 35, 2, 2, 327, 330, 7, 36, 2, 2, 328, 330, 7, 37, 2, 2, 329, 325, 3, 2, 2, 2, 329, 326, 3, 2, 2, 2, 329, 327, 3, 2, 2, 2, 329, 328, 3, 2, 2, 2, 330, 37, 3, 2, 2, 2, 36, 44, 48, 73, 83, 86, 99, 109, 127, 151, 154, 157, 159, 164, 171, 178, 185, 195, 202, 205, 210, 217, 220, 234, 239, 254, 268, 276, 279, 286, 301, 305, 312, 323, 329] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/Relay.tokens b/python/tvm/relay/grammar/py3/Relay.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py3/Relay.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 +'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py3/RelayLexer.interp b/python/tvm/relay/grammar/py3/RelayLexer.interp new file mode 100644 index 000000000000..092b3589ab70 --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.interp @@ -0,0 +1,140 @@ +token literal names: +null +'(' +')' +',' +'[' +']' +'if' +'else' +'let' +'=' +';' +'{' +'}' +'fn' +'->' +'def' +':' +'Tensor' +'_' +'v0.0.2' +null +null +null +'*' +'/' +'+' +'-' +'<' +'>' +'<=' +'>=' +'==' +'!=' +null +null +null +'mut' +null +null +null +null + +token symbolic names: +null +null +null +null +null +null +null +null +null +null 
+null +null +null +null +null +null +null +null +null +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +CNAME + +rule names: +T__0 +T__1 +T__2 +T__3 +T__4 +T__5 +T__6 +T__7 +T__8 +T__9 +T__10 +T__11 +T__12 +T__13 +T__14 +T__15 +T__16 +T__17 +SEMVER +WS +LINE_COMMENT +COMMENT +MUL +DIV +ADD +SUB +LT +GT +LE +GE +EQ +NE +GLOBAL_VAR +LOCAL_VAR +GRAPH_VAR +MUT +BOOL_LIT +FLOAT +NAT +EXP +CNAME +LETTER +DIGIT + +channel names: +DEFAULT_TOKEN_CHANNEL +HIDDEN + +mode names: +DEFAULT_MODE + +atn: +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 42, 267, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 6, 21, 149, 10, 21, 13, 21, 14, 21, 150, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 22, 7, 22, 159, 10, 22, 12, 22, 14, 22, 162, 11, 22, 3, 22, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 23, 3, 23, 7, 23, 172, 10, 23, 12, 23, 14, 23, 175, 11, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 
30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 34, 3, 34, 3, 34, 3, 35, 3, 35, 3, 35, 3, 36, 3, 36, 3, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 228, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 234, 10, 39, 3, 39, 3, 39, 3, 39, 5, 39, 239, 10, 39, 3, 40, 6, 40, 242, 10, 40, 13, 40, 14, 40, 243, 3, 41, 3, 41, 5, 41, 248, 10, 41, 3, 41, 3, 41, 3, 42, 3, 42, 5, 42, 254, 10, 42, 3, 42, 3, 42, 3, 42, 7, 42, 259, 10, 42, 12, 42, 14, 42, 262, 11, 42, 3, 43, 3, 43, 3, 44, 3, 44, 4, 160, 173, 2, 45, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 2, 83, 42, 85, 2, 87, 2, 3, 2, 7, 5, 2, 11, 12, 15, 15, 34, 34, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 67, 92, 99, 124, 3, 2, 50, 59, 2, 275, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 3, 89, 3, 2, 2, 2, 5, 91, 3, 2, 2, 2, 7, 93, 3, 2, 2, 2, 9, 95, 3, 2, 2, 2, 11, 97, 3, 2, 2, 2, 13, 99, 3, 2, 2, 2, 15, 102, 3, 2, 2, 2, 17, 107, 3, 2, 2, 2, 19, 111, 3, 2, 2, 2, 21, 113, 3, 
2, 2, 2, 23, 115, 3, 2, 2, 2, 25, 117, 3, 2, 2, 2, 27, 119, 3, 2, 2, 2, 29, 122, 3, 2, 2, 2, 31, 125, 3, 2, 2, 2, 33, 129, 3, 2, 2, 2, 35, 131, 3, 2, 2, 2, 37, 138, 3, 2, 2, 2, 39, 140, 3, 2, 2, 2, 41, 148, 3, 2, 2, 2, 43, 154, 3, 2, 2, 2, 45, 167, 3, 2, 2, 2, 47, 181, 3, 2, 2, 2, 49, 183, 3, 2, 2, 2, 51, 185, 3, 2, 2, 2, 53, 187, 3, 2, 2, 2, 55, 189, 3, 2, 2, 2, 57, 191, 3, 2, 2, 2, 59, 193, 3, 2, 2, 2, 61, 196, 3, 2, 2, 2, 63, 199, 3, 2, 2, 2, 65, 202, 3, 2, 2, 2, 67, 205, 3, 2, 2, 2, 69, 208, 3, 2, 2, 2, 71, 211, 3, 2, 2, 2, 73, 214, 3, 2, 2, 2, 75, 227, 3, 2, 2, 2, 77, 238, 3, 2, 2, 2, 79, 241, 3, 2, 2, 2, 81, 245, 3, 2, 2, 2, 83, 253, 3, 2, 2, 2, 85, 263, 3, 2, 2, 2, 87, 265, 3, 2, 2, 2, 89, 90, 7, 42, 2, 2, 90, 4, 3, 2, 2, 2, 91, 92, 7, 43, 2, 2, 92, 6, 3, 2, 2, 2, 93, 94, 7, 46, 2, 2, 94, 8, 3, 2, 2, 2, 95, 96, 7, 93, 2, 2, 96, 10, 3, 2, 2, 2, 97, 98, 7, 95, 2, 2, 98, 12, 3, 2, 2, 2, 99, 100, 7, 107, 2, 2, 100, 101, 7, 104, 2, 2, 101, 14, 3, 2, 2, 2, 102, 103, 7, 103, 2, 2, 103, 104, 7, 110, 2, 2, 104, 105, 7, 117, 2, 2, 105, 106, 7, 103, 2, 2, 106, 16, 3, 2, 2, 2, 107, 108, 7, 110, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 18, 3, 2, 2, 2, 111, 112, 7, 63, 2, 2, 112, 20, 3, 2, 2, 2, 113, 114, 7, 61, 2, 2, 114, 22, 3, 2, 2, 2, 115, 116, 7, 125, 2, 2, 116, 24, 3, 2, 2, 2, 117, 118, 7, 127, 2, 2, 118, 26, 3, 2, 2, 2, 119, 120, 7, 104, 2, 2, 120, 121, 7, 112, 2, 2, 121, 28, 3, 2, 2, 2, 122, 123, 7, 47, 2, 2, 123, 124, 7, 64, 2, 2, 124, 30, 3, 2, 2, 2, 125, 126, 7, 102, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 104, 2, 2, 128, 32, 3, 2, 2, 2, 129, 130, 7, 60, 2, 2, 130, 34, 3, 2, 2, 2, 131, 132, 7, 86, 2, 2, 132, 133, 7, 103, 2, 2, 133, 134, 7, 112, 2, 2, 134, 135, 7, 117, 2, 2, 135, 136, 7, 113, 2, 2, 136, 137, 7, 116, 2, 2, 137, 36, 3, 2, 2, 2, 138, 139, 7, 97, 2, 2, 139, 38, 3, 2, 2, 2, 140, 141, 7, 120, 2, 2, 141, 142, 7, 50, 2, 2, 142, 143, 7, 48, 2, 2, 143, 144, 7, 50, 2, 2, 144, 145, 7, 48, 2, 2, 145, 146, 7, 52, 2, 2, 146, 40, 3, 2, 2, 
2, 147, 149, 9, 2, 2, 2, 148, 147, 3, 2, 2, 2, 149, 150, 3, 2, 2, 2, 150, 148, 3, 2, 2, 2, 150, 151, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 153, 8, 21, 2, 2, 153, 42, 3, 2, 2, 2, 154, 155, 7, 49, 2, 2, 155, 156, 7, 49, 2, 2, 156, 160, 3, 2, 2, 2, 157, 159, 11, 2, 2, 2, 158, 157, 3, 2, 2, 2, 159, 162, 3, 2, 2, 2, 160, 161, 3, 2, 2, 2, 160, 158, 3, 2, 2, 2, 161, 163, 3, 2, 2, 2, 162, 160, 3, 2, 2, 2, 163, 164, 7, 12, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 8, 22, 2, 2, 166, 44, 3, 2, 2, 2, 167, 168, 7, 49, 2, 2, 168, 169, 7, 44, 2, 2, 169, 173, 3, 2, 2, 2, 170, 172, 11, 2, 2, 2, 171, 170, 3, 2, 2, 2, 172, 175, 3, 2, 2, 2, 173, 174, 3, 2, 2, 2, 173, 171, 3, 2, 2, 2, 174, 176, 3, 2, 2, 2, 175, 173, 3, 2, 2, 2, 176, 177, 7, 44, 2, 2, 177, 178, 7, 49, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 8, 23, 2, 2, 180, 46, 3, 2, 2, 2, 181, 182, 7, 44, 2, 2, 182, 48, 3, 2, 2, 2, 183, 184, 7, 49, 2, 2, 184, 50, 3, 2, 2, 2, 185, 186, 7, 45, 2, 2, 186, 52, 3, 2, 2, 2, 187, 188, 7, 47, 2, 2, 188, 54, 3, 2, 2, 2, 189, 190, 7, 62, 2, 2, 190, 56, 3, 2, 2, 2, 191, 192, 7, 64, 2, 2, 192, 58, 3, 2, 2, 2, 193, 194, 7, 62, 2, 2, 194, 195, 7, 63, 2, 2, 195, 60, 3, 2, 2, 2, 196, 197, 7, 64, 2, 2, 197, 198, 7, 63, 2, 2, 198, 62, 3, 2, 2, 2, 199, 200, 7, 63, 2, 2, 200, 201, 7, 63, 2, 2, 201, 64, 3, 2, 2, 2, 202, 203, 7, 35, 2, 2, 203, 204, 7, 63, 2, 2, 204, 66, 3, 2, 2, 2, 205, 206, 7, 66, 2, 2, 206, 207, 5, 83, 42, 2, 207, 68, 3, 2, 2, 2, 208, 209, 7, 39, 2, 2, 209, 210, 5, 83, 42, 2, 210, 70, 3, 2, 2, 2, 211, 212, 7, 39, 2, 2, 212, 213, 5, 79, 40, 2, 213, 72, 3, 2, 2, 2, 214, 215, 7, 111, 2, 2, 215, 216, 7, 119, 2, 2, 216, 217, 7, 118, 2, 2, 217, 74, 3, 2, 2, 2, 218, 219, 7, 86, 2, 2, 219, 220, 7, 116, 2, 2, 220, 221, 7, 119, 2, 2, 221, 228, 7, 103, 2, 2, 222, 223, 7, 72, 2, 2, 223, 224, 7, 99, 2, 2, 224, 225, 7, 110, 2, 2, 225, 226, 7, 117, 2, 2, 226, 228, 7, 103, 2, 2, 227, 218, 3, 2, 2, 2, 227, 222, 3, 2, 2, 2, 228, 76, 3, 2, 2, 2, 229, 230, 5, 79, 40, 2, 230, 231, 7, 48, 2, 2, 231, 233, 5, 79, 
40, 2, 232, 234, 5, 81, 41, 2, 233, 232, 3, 2, 2, 2, 233, 234, 3, 2, 2, 2, 234, 239, 3, 2, 2, 2, 235, 236, 5, 79, 40, 2, 236, 237, 5, 81, 41, 2, 237, 239, 3, 2, 2, 2, 238, 229, 3, 2, 2, 2, 238, 235, 3, 2, 2, 2, 239, 78, 3, 2, 2, 2, 240, 242, 5, 87, 44, 2, 241, 240, 3, 2, 2, 2, 242, 243, 3, 2, 2, 2, 243, 241, 3, 2, 2, 2, 243, 244, 3, 2, 2, 2, 244, 80, 3, 2, 2, 2, 245, 247, 9, 3, 2, 2, 246, 248, 9, 4, 2, 2, 247, 246, 3, 2, 2, 2, 247, 248, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 250, 5, 79, 40, 2, 250, 82, 3, 2, 2, 2, 251, 254, 7, 97, 2, 2, 252, 254, 5, 85, 43, 2, 253, 251, 3, 2, 2, 2, 253, 252, 3, 2, 2, 2, 254, 260, 3, 2, 2, 2, 255, 259, 7, 97, 2, 2, 256, 259, 5, 85, 43, 2, 257, 259, 5, 87, 44, 2, 258, 255, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 258, 257, 3, 2, 2, 2, 259, 262, 3, 2, 2, 2, 260, 258, 3, 2, 2, 2, 260, 261, 3, 2, 2, 2, 261, 84, 3, 2, 2, 2, 262, 260, 3, 2, 2, 2, 263, 264, 9, 5, 2, 2, 264, 86, 3, 2, 2, 2, 265, 266, 9, 6, 2, 2, 266, 88, 3, 2, 2, 2, 14, 2, 150, 160, 173, 227, 233, 238, 243, 247, 253, 258, 260, 3, 8, 2, 2] \ No newline at end of file diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py new file mode 100644 index 000000000000..fbf74bf1411b --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.py @@ -0,0 +1,203 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * +from io import StringIO +from typing.io import TextIO +import sys + + + +def serializedATN(): + with StringIO() as buf: + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\2*") + buf.write("\u010b\b\1\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7") + buf.write("\t\7\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r") + buf.write("\4\16\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23") + buf.write("\t\23\4\24\t\24\4\25\t\25\4\26\t\26\4\27\t\27\4\30\t\30") + buf.write("\4\31\t\31\4\32\t\32\4\33\t\33\4\34\t\34\4\35\t\35\4\36") + buf.write("\t\36\4\37\t\37\4 \t 
\4!\t!\4\"\t\"\4#\t#\4$\t$\4%\t%") + buf.write("\4&\t&\4\'\t\'\4(\t(\4)\t)\4*\t*\4+\t+\4,\t,\3\2\3\2\3") + buf.write("\3\3\3\3\4\3\4\3\5\3\5\3\6\3\6\3\7\3\7\3\7\3\b\3\b\3\b") + buf.write("\3\b\3\b\3\t\3\t\3\t\3\t\3\n\3\n\3\13\3\13\3\f\3\f\3\r") + buf.write("\3\r\3\16\3\16\3\16\3\17\3\17\3\17\3\20\3\20\3\20\3\20") + buf.write("\3\21\3\21\3\22\3\22\3\22\3\22\3\22\3\22\3\22\3\23\3\23") + buf.write("\3\24\3\24\3\24\3\24\3\24\3\24\3\24\3\25\6\25\u0095\n") + buf.write("\25\r\25\16\25\u0096\3\25\3\25\3\26\3\26\3\26\3\26\7\26") + buf.write("\u009f\n\26\f\26\16\26\u00a2\13\26\3\26\3\26\3\26\3\26") + buf.write("\3\27\3\27\3\27\3\27\7\27\u00ac\n\27\f\27\16\27\u00af") + buf.write("\13\27\3\27\3\27\3\27\3\27\3\27\3\30\3\30\3\31\3\31\3") + buf.write("\32\3\32\3\33\3\33\3\34\3\34\3\35\3\35\3\36\3\36\3\36") + buf.write("\3\37\3\37\3\37\3 \3 \3 \3!\3!\3!\3\"\3\"\3\"\3#\3#\3") + buf.write("#\3$\3$\3$\3%\3%\3%\3%\3&\3&\3&\3&\3&\3&\3&\3&\3&\5&\u00e4") + buf.write("\n&\3\'\3\'\3\'\3\'\5\'\u00ea\n\'\3\'\3\'\3\'\5\'\u00ef") + buf.write("\n\'\3(\6(\u00f2\n(\r(\16(\u00f3\3)\3)\5)\u00f8\n)\3)") + buf.write("\3)\3*\3*\5*\u00fe\n*\3*\3*\3*\7*\u0103\n*\f*\16*\u0106") + buf.write("\13*\3+\3+\3,\3,\4\u00a0\u00ad\2-\3\3\5\4\7\5\t\6\13\7") + buf.write("\r\b\17\t\21\n\23\13\25\f\27\r\31\16\33\17\35\20\37\21") + buf.write("!\22#\23%\24\'\25)\26+\27-\30/\31\61\32\63\33\65\34\67") + buf.write("\359\36;\37= ?!A\"C#E$G%I&K\'M(O)Q\2S*U\2W\2\3\2\7\5\2") + buf.write("\13\f\17\17\"\"\4\2GGgg\4\2--//\4\2C\\c|\3\2\62;\2\u0113") + buf.write("\2\3\3\2\2\2\2\5\3\2\2\2\2\7\3\2\2\2\2\t\3\2\2\2\2\13") + buf.write("\3\2\2\2\2\r\3\2\2\2\2\17\3\2\2\2\2\21\3\2\2\2\2\23\3") + buf.write("\2\2\2\2\25\3\2\2\2\2\27\3\2\2\2\2\31\3\2\2\2\2\33\3\2") + buf.write("\2\2\2\35\3\2\2\2\2\37\3\2\2\2\2!\3\2\2\2\2#\3\2\2\2\2") + buf.write("%\3\2\2\2\2\'\3\2\2\2\2)\3\2\2\2\2+\3\2\2\2\2-\3\2\2\2") + buf.write("\2/\3\2\2\2\2\61\3\2\2\2\2\63\3\2\2\2\2\65\3\2\2\2\2\67") + buf.write("\3\2\2\2\29\3\2\2\2\2;\3\2\2\2\2=\3\2\2\2\2?\3\2\2\2\2") + 
buf.write("A\3\2\2\2\2C\3\2\2\2\2E\3\2\2\2\2G\3\2\2\2\2I\3\2\2\2") + buf.write("\2K\3\2\2\2\2M\3\2\2\2\2O\3\2\2\2\2S\3\2\2\2\3Y\3\2\2") + buf.write("\2\5[\3\2\2\2\7]\3\2\2\2\t_\3\2\2\2\13a\3\2\2\2\rc\3\2") + buf.write("\2\2\17f\3\2\2\2\21k\3\2\2\2\23o\3\2\2\2\25q\3\2\2\2\27") + buf.write("s\3\2\2\2\31u\3\2\2\2\33w\3\2\2\2\35z\3\2\2\2\37}\3\2") + buf.write("\2\2!\u0081\3\2\2\2#\u0083\3\2\2\2%\u008a\3\2\2\2\'\u008c") + buf.write("\3\2\2\2)\u0094\3\2\2\2+\u009a\3\2\2\2-\u00a7\3\2\2\2") + buf.write("/\u00b5\3\2\2\2\61\u00b7\3\2\2\2\63\u00b9\3\2\2\2\65\u00bb") + buf.write("\3\2\2\2\67\u00bd\3\2\2\29\u00bf\3\2\2\2;\u00c1\3\2\2") + buf.write("\2=\u00c4\3\2\2\2?\u00c7\3\2\2\2A\u00ca\3\2\2\2C\u00cd") + buf.write("\3\2\2\2E\u00d0\3\2\2\2G\u00d3\3\2\2\2I\u00d6\3\2\2\2") + buf.write("K\u00e3\3\2\2\2M\u00ee\3\2\2\2O\u00f1\3\2\2\2Q\u00f5\3") + buf.write("\2\2\2S\u00fd\3\2\2\2U\u0107\3\2\2\2W\u0109\3\2\2\2YZ") + buf.write("\7*\2\2Z\4\3\2\2\2[\\\7+\2\2\\\6\3\2\2\2]^\7.\2\2^\b\3") + buf.write("\2\2\2_`\7]\2\2`\n\3\2\2\2ab\7_\2\2b\f\3\2\2\2cd\7k\2") + buf.write("\2de\7h\2\2e\16\3\2\2\2fg\7g\2\2gh\7n\2\2hi\7u\2\2ij\7") + buf.write("g\2\2j\20\3\2\2\2kl\7n\2\2lm\7g\2\2mn\7v\2\2n\22\3\2\2") + buf.write("\2op\7?\2\2p\24\3\2\2\2qr\7=\2\2r\26\3\2\2\2st\7}\2\2") + buf.write("t\30\3\2\2\2uv\7\177\2\2v\32\3\2\2\2wx\7h\2\2xy\7p\2\2") + buf.write("y\34\3\2\2\2z{\7/\2\2{|\7@\2\2|\36\3\2\2\2}~\7f\2\2~\177") + buf.write("\7g\2\2\177\u0080\7h\2\2\u0080 \3\2\2\2\u0081\u0082\7") + buf.write("<\2\2\u0082\"\3\2\2\2\u0083\u0084\7V\2\2\u0084\u0085\7") + buf.write("g\2\2\u0085\u0086\7p\2\2\u0086\u0087\7u\2\2\u0087\u0088") + buf.write("\7q\2\2\u0088\u0089\7t\2\2\u0089$\3\2\2\2\u008a\u008b") + buf.write("\7a\2\2\u008b&\3\2\2\2\u008c\u008d\7x\2\2\u008d\u008e") + buf.write("\7\62\2\2\u008e\u008f\7\60\2\2\u008f\u0090\7\62\2\2\u0090") + buf.write("\u0091\7\60\2\2\u0091\u0092\7\64\2\2\u0092(\3\2\2\2\u0093") + buf.write("\u0095\t\2\2\2\u0094\u0093\3\2\2\2\u0095\u0096\3\2\2\2") + 
buf.write("\u0096\u0094\3\2\2\2\u0096\u0097\3\2\2\2\u0097\u0098\3") + buf.write("\2\2\2\u0098\u0099\b\25\2\2\u0099*\3\2\2\2\u009a\u009b") + buf.write("\7\61\2\2\u009b\u009c\7\61\2\2\u009c\u00a0\3\2\2\2\u009d") + buf.write("\u009f\13\2\2\2\u009e\u009d\3\2\2\2\u009f\u00a2\3\2\2") + buf.write("\2\u00a0\u00a1\3\2\2\2\u00a0\u009e\3\2\2\2\u00a1\u00a3") + buf.write("\3\2\2\2\u00a2\u00a0\3\2\2\2\u00a3\u00a4\7\f\2\2\u00a4") + buf.write("\u00a5\3\2\2\2\u00a5\u00a6\b\26\2\2\u00a6,\3\2\2\2\u00a7") + buf.write("\u00a8\7\61\2\2\u00a8\u00a9\7,\2\2\u00a9\u00ad\3\2\2\2") + buf.write("\u00aa\u00ac\13\2\2\2\u00ab\u00aa\3\2\2\2\u00ac\u00af") + buf.write("\3\2\2\2\u00ad\u00ae\3\2\2\2\u00ad\u00ab\3\2\2\2\u00ae") + buf.write("\u00b0\3\2\2\2\u00af\u00ad\3\2\2\2\u00b0\u00b1\7,\2\2") + buf.write("\u00b1\u00b2\7\61\2\2\u00b2\u00b3\3\2\2\2\u00b3\u00b4") + buf.write("\b\27\2\2\u00b4.\3\2\2\2\u00b5\u00b6\7,\2\2\u00b6\60\3") + buf.write("\2\2\2\u00b7\u00b8\7\61\2\2\u00b8\62\3\2\2\2\u00b9\u00ba") + buf.write("\7-\2\2\u00ba\64\3\2\2\2\u00bb\u00bc\7/\2\2\u00bc\66\3") + buf.write("\2\2\2\u00bd\u00be\7>\2\2\u00be8\3\2\2\2\u00bf\u00c0\7") + buf.write("@\2\2\u00c0:\3\2\2\2\u00c1\u00c2\7>\2\2\u00c2\u00c3\7") + buf.write("?\2\2\u00c3<\3\2\2\2\u00c4\u00c5\7@\2\2\u00c5\u00c6\7") + buf.write("?\2\2\u00c6>\3\2\2\2\u00c7\u00c8\7?\2\2\u00c8\u00c9\7") + buf.write("?\2\2\u00c9@\3\2\2\2\u00ca\u00cb\7#\2\2\u00cb\u00cc\7") + buf.write("?\2\2\u00ccB\3\2\2\2\u00cd\u00ce\7B\2\2\u00ce\u00cf\5") + buf.write("S*\2\u00cfD\3\2\2\2\u00d0\u00d1\7\'\2\2\u00d1\u00d2\5") + buf.write("S*\2\u00d2F\3\2\2\2\u00d3\u00d4\7\'\2\2\u00d4\u00d5\5") + buf.write("O(\2\u00d5H\3\2\2\2\u00d6\u00d7\7o\2\2\u00d7\u00d8\7w") + buf.write("\2\2\u00d8\u00d9\7v\2\2\u00d9J\3\2\2\2\u00da\u00db\7V") + buf.write("\2\2\u00db\u00dc\7t\2\2\u00dc\u00dd\7w\2\2\u00dd\u00e4") + buf.write("\7g\2\2\u00de\u00df\7H\2\2\u00df\u00e0\7c\2\2\u00e0\u00e1") + buf.write("\7n\2\2\u00e1\u00e2\7u\2\2\u00e2\u00e4\7g\2\2\u00e3\u00da") + 
buf.write("\3\2\2\2\u00e3\u00de\3\2\2\2\u00e4L\3\2\2\2\u00e5\u00e6") + buf.write("\5O(\2\u00e6\u00e7\7\60\2\2\u00e7\u00e9\5O(\2\u00e8\u00ea") + buf.write("\5Q)\2\u00e9\u00e8\3\2\2\2\u00e9\u00ea\3\2\2\2\u00ea\u00ef") + buf.write("\3\2\2\2\u00eb\u00ec\5O(\2\u00ec\u00ed\5Q)\2\u00ed\u00ef") + buf.write("\3\2\2\2\u00ee\u00e5\3\2\2\2\u00ee\u00eb\3\2\2\2\u00ef") + buf.write("N\3\2\2\2\u00f0\u00f2\5W,\2\u00f1\u00f0\3\2\2\2\u00f2") + buf.write("\u00f3\3\2\2\2\u00f3\u00f1\3\2\2\2\u00f3\u00f4\3\2\2\2") + buf.write("\u00f4P\3\2\2\2\u00f5\u00f7\t\3\2\2\u00f6\u00f8\t\4\2") + buf.write("\2\u00f7\u00f6\3\2\2\2\u00f7\u00f8\3\2\2\2\u00f8\u00f9") + buf.write("\3\2\2\2\u00f9\u00fa\5O(\2\u00faR\3\2\2\2\u00fb\u00fe") + buf.write("\7a\2\2\u00fc\u00fe\5U+\2\u00fd\u00fb\3\2\2\2\u00fd\u00fc") + buf.write("\3\2\2\2\u00fe\u0104\3\2\2\2\u00ff\u0103\7a\2\2\u0100") + buf.write("\u0103\5U+\2\u0101\u0103\5W,\2\u0102\u00ff\3\2\2\2\u0102") + buf.write("\u0100\3\2\2\2\u0102\u0101\3\2\2\2\u0103\u0106\3\2\2\2") + buf.write("\u0104\u0102\3\2\2\2\u0104\u0105\3\2\2\2\u0105T\3\2\2") + buf.write("\2\u0106\u0104\3\2\2\2\u0107\u0108\t\5\2\2\u0108V\3\2") + buf.write("\2\2\u0109\u010a\t\6\2\2\u010aX\3\2\2\2\16\2\u0096\u00a0") + buf.write("\u00ad\u00e3\u00e9\u00ee\u00f3\u00f7\u00fd\u0102\u0104") + buf.write("\3\b\2\2") + return buf.getvalue() + + +class RelayLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + SEMVER = 19 + WS = 20 + LINE_COMMENT = 21 + COMMENT = 22 + MUL = 23 + DIV = 24 + ADD = 25 + SUB = 26 + LT = 27 + GT = 28 + LE = 29 + GE = 30 + EQ = 31 + NE = 32 + GLOBAL_VAR = 33 + LOCAL_VAR = 34 + GRAPH_VAR = 35 + MUT = 36 + BOOL_LIT = 37 + FLOAT = 38 + NAT = 39 + CNAME = 40 + + channelNames = 
[ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "'('", "')'", "','", "'['", "']'", "'if'", "'else'", "'let'", + "'='", "';'", "'{'", "'}'", "'fn'", "'->'", "'def'", "':'", + "'Tensor'", "'_'", "'v0.0.2'", "'*'", "'/'", "'+'", "'-'", "'<'", + "'>'", "'<='", "'>='", "'=='", "'!='", "'mut'" ] + + symbolicNames = [ "", + "SEMVER", "WS", "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", + "SUB", "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", + "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", "NAT", "CNAME" ] + + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "SEMVER", "WS", "LINE_COMMENT", + "COMMENT", "MUL", "DIV", "ADD", "SUB", "LT", "GT", "LE", + "GE", "EQ", "NE", "GLOBAL_VAR", "LOCAL_VAR", "GRAPH_VAR", + "MUT", "BOOL_LIT", "FLOAT", "NAT", "EXP", "CNAME", "LETTER", + "DIGIT" ] + + grammarFileName = "Relay.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.7.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/python/tvm/relay/grammar/py3/RelayLexer.tokens b/python/tvm/relay/grammar/py3/RelayLexer.tokens new file mode 100644 index 000000000000..41f3ee62a86c --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayLexer.tokens @@ -0,0 +1,70 @@ +T__0=1 +T__1=2 +T__2=3 +T__3=4 +T__4=5 +T__5=6 +T__6=7 +T__7=8 +T__8=9 +T__9=10 +T__10=11 +T__11=12 +T__12=13 +T__13=14 +T__14=15 +T__15=16 +T__16=17 +T__17=18 +SEMVER=19 +WS=20 +LINE_COMMENT=21 +COMMENT=22 +MUL=23 +DIV=24 +ADD=25 +SUB=26 +LT=27 +GT=28 +LE=29 +GE=30 +EQ=31 +NE=32 +GLOBAL_VAR=33 +LOCAL_VAR=34 +GRAPH_VAR=35 +MUT=36 +BOOL_LIT=37 +FLOAT=38 +NAT=39 +CNAME=40 +'('=1 +')'=2 +','=3 +'['=4 +']'=5 +'if'=6 +'else'=7 +'let'=8 +'='=9 +';'=10 +'{'=11 +'}'=12 
+'fn'=13 +'->'=14 +'def'=15 +':'=16 +'Tensor'=17 +'_'=18 +'v0.0.2'=19 +'*'=23 +'/'=24 +'+'=25 +'-'=26 +'<'=27 +'>'=28 +'<='=29 +'>='=30 +'=='=31 +'!='=32 +'mut'=36 diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py new file mode 100644 index 000000000000..ff5cffc36a9f --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -0,0 +1,2307 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# encoding: utf-8 +from antlr4 import * +from io import StringIO +from typing.io import TextIO +import sys + + +def serializedATN(): + with StringIO() as buf: + buf.write("\3\u608b\ua72a\u8133\ub9ed\u417c\u3be7\u7786\u5964\3*") + buf.write("\u014c\4\2\t\2\4\3\t\3\4\4\t\4\4\5\t\5\4\6\t\6\4\7\t\7") + buf.write("\4\b\t\b\4\t\t\t\4\n\t\n\4\13\t\13\4\f\t\f\4\r\t\r\4\16") + buf.write("\t\16\4\17\t\17\4\20\t\20\4\21\t\21\4\22\t\22\4\23\t\23") + buf.write("\3\2\3\2\3\3\3\3\7\3+\n\3\f\3\16\3.\13\3\3\3\5\3\61\n") + buf.write("\3\3\3\3\3\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\6\4H\n\4\r\4\16\4I\3") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\7\4R\n\4\f\4\16\4U\13\4\5\4W\n") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4d\n") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\5\4n\n\4\3\4\3\4\3") + buf.write("\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\5\4\u0080\n\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4") + buf.write("\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\3\4\7\4\u0096\n\4") + buf.write("\f\4\16\4\u0099\13\4\5\4\u009b\n\4\3\4\7\4\u009e\n\4\f") + buf.write("\4\16\4\u00a1\13\4\3\5\3\5\5\5\u00a5\n\5\3\5\3\5\3\5\3") + buf.write("\5\3\5\5\5\u00ac\n\5\3\5\3\5\3\6\3\6\3\6\5\6\u00b3\n\6") + buf.write("\3\6\3\6\3\6\3\6\3\6\5\6\u00ba\n\6\3\6\3\6\3\7\3\7\3\7") + buf.write("\3\7\3\7\3\7\5\7\u00c4\n\7\3\b\3\b\3\b\7\b\u00c9\n\b\f") + buf.write("\b\16\b\u00cc\13\b\5\b\u00ce\n\b\3\t\3\t\3\t\5\t\u00d3") + 
buf.write("\n\t\3\n\3\n\3\n\7\n\u00d8\n\n\f\n\16\n\u00db\13\n\5\n") + buf.write("\u00dd\n\n\3\13\3\13\3\13\3\13\3\f\3\f\3\f\3\f\3\f\3\f") + buf.write("\7\f\u00e9\n\f\f\f\16\f\u00ec\13\f\3\f\3\f\5\f\u00f0\n") + buf.write("\f\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\6\r\u00fd") + buf.write("\n\r\r\r\16\r\u00fe\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3\r\3") + buf.write("\r\3\r\3\r\3\r\5\r\u010d\n\r\3\r\3\r\3\r\3\r\7\r\u0113") + buf.write("\n\r\f\r\16\r\u0116\13\r\5\r\u0118\n\r\3\r\3\r\3\r\3\r") + buf.write("\3\r\5\r\u011f\n\r\3\16\3\16\3\16\3\16\3\16\3\16\3\16") + buf.write("\3\16\3\16\3\16\3\16\6\16\u012c\n\16\r\16\16\16\u012d") + buf.write("\3\16\3\16\5\16\u0132\n\16\3\17\3\17\3\17\3\17\3\17\5") + buf.write("\17\u0139\n\17\3\20\3\20\3\21\3\21\3\21\3\21\3\22\3\22") + buf.write("\3\22\5\22\u0144\n\22\3\23\3\23\3\23\3\23\5\23\u014a\n") + buf.write("\23\3\23\2\3\6\24\2\4\6\b\n\f\16\20\22\24\26\30\32\34") + buf.write("\36 \"$\2\6\3\2\31\32\3\2\33\34\3\2\35 \3\2!\"\2\u0175") + buf.write("\2&\3\2\2\2\4(\3\2\2\2\6\177\3\2\2\2\b\u00a2\3\2\2\2\n") + buf.write("\u00af\3\2\2\2\f\u00c3\3\2\2\2\16\u00cd\3\2\2\2\20\u00cf") + buf.write("\3\2\2\2\22\u00dc\3\2\2\2\24\u00de\3\2\2\2\26\u00ef\3") + buf.write("\2\2\2\30\u011e\3\2\2\2\32\u0131\3\2\2\2\34\u0138\3\2") + buf.write("\2\2\36\u013a\3\2\2\2 \u013c\3\2\2\2\"\u0143\3\2\2\2$") + buf.write("\u0149\3\2\2\2&\'\7*\2\2\'\3\3\2\2\2(\60\7\25\2\2)+\5") + buf.write("\n\6\2*)\3\2\2\2+.\3\2\2\2,*\3\2\2\2,-\3\2\2\2-\61\3\2") + buf.write("\2\2.,\3\2\2\2/\61\5\6\4\2\60,\3\2\2\2\60/\3\2\2\2\61") + buf.write("\62\3\2\2\2\62\63\7\2\2\3\63\5\3\2\2\2\64\65\b\4\1\2\65") + buf.write("\66\7\3\2\2\66\67\5\6\4\2\678\7\4\2\28\u0080\3\2\2\29") + buf.write(":\7\34\2\2:\u0080\5\6\4\23;\u0080\5\b\5\2<=\7\3\2\2=\u0080") + buf.write("\7\4\2\2>?\7\3\2\2?@\5\6\4\2@A\7\5\2\2AB\7\4\2\2B\u0080") + buf.write("\3\2\2\2CD\7\3\2\2DG\5\6\4\2EF\7\5\2\2FH\5\6\4\2GE\3\2") + buf.write("\2\2HI\3\2\2\2IG\3\2\2\2IJ\3\2\2\2JK\3\2\2\2KL\7\4\2\2") + 
buf.write("L\u0080\3\2\2\2MV\7\6\2\2NS\5\6\4\2OP\7\5\2\2PR\5\6\4") + buf.write("\2QO\3\2\2\2RU\3\2\2\2SQ\3\2\2\2ST\3\2\2\2TW\3\2\2\2U") + buf.write("S\3\2\2\2VN\3\2\2\2VW\3\2\2\2WX\3\2\2\2X\u0080\7\7\2\2") + buf.write("YZ\7\b\2\2Z[\7\3\2\2[\\\5\6\4\2\\]\7\4\2\2]^\5 \21\2^") + buf.write("_\7\t\2\2_`\5 \21\2`\u0080\3\2\2\2ac\7\n\2\2bd\7&\2\2") + buf.write("cb\3\2\2\2cd\3\2\2\2de\3\2\2\2ef\5\20\t\2fg\7\13\2\2g") + buf.write("h\5\6\4\2hi\7\f\2\2ij\5\6\4\bj\u0080\3\2\2\2km\7\n\2\2") + buf.write("ln\7&\2\2ml\3\2\2\2mn\3\2\2\2no\3\2\2\2op\5\20\t\2pq\7") + buf.write("\13\2\2qr\7\r\2\2rs\5\6\4\2st\7\16\2\2tu\7\f\2\2uv\5\6") + buf.write("\4\7v\u0080\3\2\2\2wx\5$\23\2xy\7\13\2\2yz\5\6\4\2z{\7") + buf.write("\f\2\2{|\5\6\4\5|\u0080\3\2\2\2}\u0080\5$\23\2~\u0080") + buf.write("\5\"\22\2\177\64\3\2\2\2\1779\3\2\2\2\177;\3\2\2\2\177") + buf.write("<\3\2\2\2\177>\3\2\2\2\177C\3\2\2\2\177M\3\2\2\2\177Y") + buf.write("\3\2\2\2\177a\3\2\2\2\177k\3\2\2\2\177w\3\2\2\2\177}\3") + buf.write("\2\2\2\177~\3\2\2\2\u0080\u009f\3\2\2\2\u0081\u0082\f") + buf.write("\22\2\2\u0082\u0083\t\2\2\2\u0083\u009e\5\6\4\23\u0084") + buf.write("\u0085\f\21\2\2\u0085\u0086\t\3\2\2\u0086\u009e\5\6\4") + buf.write("\22\u0087\u0088\f\20\2\2\u0088\u0089\t\4\2\2\u0089\u009e") + buf.write("\5\6\4\21\u008a\u008b\f\17\2\2\u008b\u008c\t\5\2\2\u008c") + buf.write("\u009e\5\6\4\20\u008d\u008e\f\6\2\2\u008e\u008f\7\f\2") + buf.write("\2\u008f\u009e\5\6\4\7\u0090\u0091\f\24\2\2\u0091\u009a") + buf.write("\7\3\2\2\u0092\u0097\5\6\4\2\u0093\u0094\7\5\2\2\u0094") + buf.write("\u0096\5\6\4\2\u0095\u0093\3\2\2\2\u0096\u0099\3\2\2\2") + buf.write("\u0097\u0095\3\2\2\2\u0097\u0098\3\2\2\2\u0098\u009b\3") + buf.write("\2\2\2\u0099\u0097\3\2\2\2\u009a\u0092\3\2\2\2\u009a\u009b") + buf.write("\3\2\2\2\u009b\u009c\3\2\2\2\u009c\u009e\7\4\2\2\u009d") + buf.write("\u0081\3\2\2\2\u009d\u0084\3\2\2\2\u009d\u0087\3\2\2\2") + buf.write("\u009d\u008a\3\2\2\2\u009d\u008d\3\2\2\2\u009d\u0090\3") + 
buf.write("\2\2\2\u009e\u00a1\3\2\2\2\u009f\u009d\3\2\2\2\u009f\u00a0") + buf.write("\3\2\2\2\u00a0\7\3\2\2\2\u00a1\u009f\3\2\2\2\u00a2\u00a4") + buf.write("\7\17\2\2\u00a3\u00a5\5\26\f\2\u00a4\u00a3\3\2\2\2\u00a4") + buf.write("\u00a5\3\2\2\2\u00a5\u00a6\3\2\2\2\u00a6\u00a7\7\3\2\2") + buf.write("\u00a7\u00a8\5\f\7\2\u00a8\u00ab\7\4\2\2\u00a9\u00aa\7") + buf.write("\20\2\2\u00aa\u00ac\5\30\r\2\u00ab\u00a9\3\2\2\2\u00ab") + buf.write("\u00ac\3\2\2\2\u00ac\u00ad\3\2\2\2\u00ad\u00ae\5 \21\2") + buf.write("\u00ae\t\3\2\2\2\u00af\u00b0\7\21\2\2\u00b0\u00b2\5$\23") + buf.write("\2\u00b1\u00b3\5\26\f\2\u00b2\u00b1\3\2\2\2\u00b2\u00b3") + buf.write("\3\2\2\2\u00b3\u00b4\3\2\2\2\u00b4\u00b5\7\3\2\2\u00b5") + buf.write("\u00b6\5\f\7\2\u00b6\u00b9\7\4\2\2\u00b7\u00b8\7\20\2") + buf.write("\2\u00b8\u00ba\5\30\r\2\u00b9\u00b7\3\2\2\2\u00b9\u00ba") + buf.write("\3\2\2\2\u00ba\u00bb\3\2\2\2\u00bb\u00bc\5 \21\2\u00bc") + buf.write("\13\3\2\2\2\u00bd\u00c4\5\16\b\2\u00be\u00c4\5\22\n\2") + buf.write("\u00bf\u00c0\5\16\b\2\u00c0\u00c1\7\5\2\2\u00c1\u00c2") + buf.write("\5\22\n\2\u00c2\u00c4\3\2\2\2\u00c3\u00bd\3\2\2\2\u00c3") + buf.write("\u00be\3\2\2\2\u00c3\u00bf\3\2\2\2\u00c4\r\3\2\2\2\u00c5") + buf.write("\u00ca\5\20\t\2\u00c6\u00c7\7\5\2\2\u00c7\u00c9\5\20\t") + buf.write("\2\u00c8\u00c6\3\2\2\2\u00c9\u00cc\3\2\2\2\u00ca\u00c8") + buf.write("\3\2\2\2\u00ca\u00cb\3\2\2\2\u00cb\u00ce\3\2\2\2\u00cc") + buf.write("\u00ca\3\2\2\2\u00cd\u00c5\3\2\2\2\u00cd\u00ce\3\2\2\2") + buf.write("\u00ce\17\3\2\2\2\u00cf\u00d2\5$\23\2\u00d0\u00d1\7\22") + buf.write("\2\2\u00d1\u00d3\5\30\r\2\u00d2\u00d0\3\2\2\2\u00d2\u00d3") + buf.write("\3\2\2\2\u00d3\21\3\2\2\2\u00d4\u00d9\5\24\13\2\u00d5") + buf.write("\u00d6\7\5\2\2\u00d6\u00d8\5\24\13\2\u00d7\u00d5\3\2\2") + buf.write("\2\u00d8\u00db\3\2\2\2\u00d9\u00d7\3\2\2\2\u00d9\u00da") + buf.write("\3\2\2\2\u00da\u00dd\3\2\2\2\u00db\u00d9\3\2\2\2\u00dc") + buf.write("\u00d4\3\2\2\2\u00dc\u00dd\3\2\2\2\u00dd\23\3\2\2\2\u00de") + 
buf.write("\u00df\7*\2\2\u00df\u00e0\7\13\2\2\u00e0\u00e1\5\6\4\2") + buf.write("\u00e1\25\3\2\2\2\u00e2\u00e3\7\6\2\2\u00e3\u00f0\7\7") + buf.write("\2\2\u00e4\u00e5\7\6\2\2\u00e5\u00ea\5$\23\2\u00e6\u00e7") + buf.write("\7\5\2\2\u00e7\u00e9\5$\23\2\u00e8\u00e6\3\2\2\2\u00e9") + buf.write("\u00ec\3\2\2\2\u00ea\u00e8\3\2\2\2\u00ea\u00eb\3\2\2\2") + buf.write("\u00eb\u00ed\3\2\2\2\u00ec\u00ea\3\2\2\2\u00ed\u00ee\7") + buf.write("\7\2\2\u00ee\u00f0\3\2\2\2\u00ef\u00e2\3\2\2\2\u00ef\u00e4") + buf.write("\3\2\2\2\u00f0\27\3\2\2\2\u00f1\u00f2\7\3\2\2\u00f2\u011f") + buf.write("\7\4\2\2\u00f3\u00f4\7\3\2\2\u00f4\u00f5\5\30\r\2\u00f5") + buf.write("\u00f6\7\5\2\2\u00f6\u00f7\7\4\2\2\u00f7\u011f\3\2\2\2") + buf.write("\u00f8\u00f9\7\3\2\2\u00f9\u00fc\5\30\r\2\u00fa\u00fb") + buf.write("\7\5\2\2\u00fb\u00fd\5\30\r\2\u00fc\u00fa\3\2\2\2\u00fd") + buf.write("\u00fe\3\2\2\2\u00fe\u00fc\3\2\2\2\u00fe\u00ff\3\2\2\2") + buf.write("\u00ff\u0100\3\2\2\2\u0100\u0101\7\4\2\2\u0101\u011f\3") + buf.write("\2\2\2\u0102\u011f\5\36\20\2\u0103\u0104\7\23\2\2\u0104") + buf.write("\u0105\7\6\2\2\u0105\u0106\5\32\16\2\u0106\u0107\7\5\2") + buf.write("\2\u0107\u0108\5\30\r\2\u0108\u0109\7\7\2\2\u0109\u011f") + buf.write("\3\2\2\2\u010a\u010c\7\17\2\2\u010b\u010d\5\26\f\2\u010c") + buf.write("\u010b\3\2\2\2\u010c\u010d\3\2\2\2\u010d\u010e\3\2\2\2") + buf.write("\u010e\u0117\7\3\2\2\u010f\u0114\5\30\r\2\u0110\u0111") + buf.write("\7\5\2\2\u0111\u0113\5\30\r\2\u0112\u0110\3\2\2\2\u0113") + buf.write("\u0116\3\2\2\2\u0114\u0112\3\2\2\2\u0114\u0115\3\2\2\2") + buf.write("\u0115\u0118\3\2\2\2\u0116\u0114\3\2\2\2\u0117\u010f\3") + buf.write("\2\2\2\u0117\u0118\3\2\2\2\u0118\u0119\3\2\2\2\u0119\u011a") + buf.write("\7\4\2\2\u011a\u011b\7\20\2\2\u011b\u011f\5\30\r\2\u011c") + buf.write("\u011f\7\24\2\2\u011d\u011f\7)\2\2\u011e\u00f1\3\2\2\2") + buf.write("\u011e\u00f3\3\2\2\2\u011e\u00f8\3\2\2\2\u011e\u0102\3") + buf.write("\2\2\2\u011e\u0103\3\2\2\2\u011e\u010a\3\2\2\2\u011e\u011c") + 
buf.write("\3\2\2\2\u011e\u011d\3\2\2\2\u011f\31\3\2\2\2\u0120\u0121") + buf.write("\7\3\2\2\u0121\u0132\7\4\2\2\u0122\u0123\7\3\2\2\u0123") + buf.write("\u0124\5\34\17\2\u0124\u0125\7\5\2\2\u0125\u0126\7\4\2") + buf.write("\2\u0126\u0132\3\2\2\2\u0127\u0128\7\3\2\2\u0128\u012b") + buf.write("\5\34\17\2\u0129\u012a\7\5\2\2\u012a\u012c\5\34\17\2\u012b") + buf.write("\u0129\3\2\2\2\u012c\u012d\3\2\2\2\u012d\u012b\3\2\2\2") + buf.write("\u012d\u012e\3\2\2\2\u012e\u012f\3\2\2\2\u012f\u0130\7") + buf.write("\4\2\2\u0130\u0132\3\2\2\2\u0131\u0120\3\2\2\2\u0131\u0122") + buf.write("\3\2\2\2\u0131\u0127\3\2\2\2\u0132\33\3\2\2\2\u0133\u0134") + buf.write("\7\3\2\2\u0134\u0135\5\34\17\2\u0135\u0136\7\4\2\2\u0136") + buf.write("\u0139\3\2\2\2\u0137\u0139\7)\2\2\u0138\u0133\3\2\2\2") + buf.write("\u0138\u0137\3\2\2\2\u0139\35\3\2\2\2\u013a\u013b\7*\2") + buf.write("\2\u013b\37\3\2\2\2\u013c\u013d\7\r\2\2\u013d\u013e\5") + buf.write("\6\4\2\u013e\u013f\7\16\2\2\u013f!\3\2\2\2\u0140\u0144") + buf.write("\7(\2\2\u0141\u0144\7)\2\2\u0142\u0144\7\'\2\2\u0143\u0140") + buf.write("\3\2\2\2\u0143\u0141\3\2\2\2\u0143\u0142\3\2\2\2\u0144") + buf.write("#\3\2\2\2\u0145\u014a\5\2\2\2\u0146\u014a\7#\2\2\u0147") + buf.write("\u014a\7$\2\2\u0148\u014a\7%\2\2\u0149\u0145\3\2\2\2\u0149") + buf.write("\u0146\3\2\2\2\u0149\u0147\3\2\2\2\u0149\u0148\3\2\2\2") + buf.write("\u014a%\3\2\2\2$,\60ISVcm\177\u0097\u009a\u009d\u009f") + buf.write("\u00a4\u00ab\u00b2\u00b9\u00c3\u00ca\u00cd\u00d2\u00d9") + buf.write("\u00dc\u00ea\u00ef\u00fe\u010c\u0114\u0117\u011e\u012d") + buf.write("\u0131\u0138\u0143\u0149") + return buf.getvalue() + + +class RelayParser ( Parser ): + + grammarFileName = "Relay.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "'('", "')'", "','", "'['", "']'", "'if'", + "'else'", "'let'", "'='", "';'", "'{'", "'}'", 
"'fn'", + "'->'", "'def'", "':'", "'Tensor'", "'_'", "'v0.0.2'", + "", "", "", "'*'", "'/'", + "'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='", + "'!='", "", "", "", "'mut'" ] + + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "SEMVER", "WS", + "LINE_COMMENT", "COMMENT", "MUL", "DIV", "ADD", "SUB", + "LT", "GT", "LE", "GE", "EQ", "NE", "GLOBAL_VAR", + "LOCAL_VAR", "GRAPH_VAR", "MUT", "BOOL_LIT", "FLOAT", + "NAT", "CNAME" ] + + RULE_opIdent = 0 + RULE_prog = 1 + RULE_expr = 2 + RULE_func = 3 + RULE_defn = 4 + RULE_argList = 5 + RULE_varList = 6 + RULE_var = 7 + RULE_attrList = 8 + RULE_attr = 9 + RULE_typeParamSeq = 10 + RULE_type_ = 11 + RULE_shapeSeq = 12 + RULE_shape = 13 + RULE_typeIdent = 14 + RULE_body = 15 + RULE_scalar = 16 + RULE_ident = 17 + + ruleNames = [ "opIdent", "prog", "expr", "func", "defn", "argList", + "varList", "var", "attrList", "attr", "typeParamSeq", + "type_", "shapeSeq", "shape", "typeIdent", "body", "scalar", + "ident" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + SEMVER=19 + WS=20 + LINE_COMMENT=21 + COMMENT=22 + MUL=23 + DIV=24 + ADD=25 + SUB=26 + LT=27 + GT=28 + LE=29 + GE=30 + EQ=31 + NE=32 + GLOBAL_VAR=33 + LOCAL_VAR=34 + GRAPH_VAR=35 + MUT=36 + BOOL_LIT=37 + FLOAT=38 + NAT=39 + CNAME=40 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.7.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class OpIdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def 
getRuleIndex(self): + return RelayParser.RULE_opIdent + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitOpIdent" ): + return visitor.visitOpIdent(self) + else: + return visitor.visitChildren(self) + + + + + def opIdent(self): + + localctx = RelayParser.OpIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_opIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 36 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ProgContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def SEMVER(self): + return self.getToken(RelayParser.SEMVER, 0) + + def EOF(self): + return self.getToken(RelayParser.EOF, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def defn(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.DefnContext) + else: + return self.getTypedRuleContext(RelayParser.DefnContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_prog + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitProg" ): + return visitor.visitProg(self) + else: + return visitor.visitChildren(self) + + + + + def prog(self): + + localctx = RelayParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 38 + self.match(RelayParser.SEMVER) + self.state = 46 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.EOF, RelayParser.T__14]: + self.state = 42 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__14: + self.state = 39 + self.defn() + self.state 
= 44 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [RelayParser.T__0, RelayParser.T__3, RelayParser.T__5, RelayParser.T__7, RelayParser.T__12, RelayParser.SUB, RelayParser.GLOBAL_VAR, RelayParser.LOCAL_VAR, RelayParser.GRAPH_VAR, RelayParser.BOOL_LIT, RelayParser.FLOAT, RelayParser.NAT, RelayParser.CNAME]: + self.state = 45 + self.expr(0) + pass + else: + raise NoViableAltException(self) + + self.state = 48 + self.match(RelayParser.EOF) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_expr + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + class IdentExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdentExpr" ): + return visitor.visitIdentExpr(self) + else: + return visitor.visitChildren(self) + + + class CallContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitCall" ): + return visitor.visitCall(self) + else: + return visitor.visitChildren(self) + + + class NegContext(ExprContext): + + def 
__init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNeg" ): + return visitor.visitNeg(self) + else: + return visitor.visitChildren(self) + + + class TupleContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTuple" ): + return visitor.visitTuple(self) + else: + return visitor.visitChildren(self) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitParens" ): + return visitor.visitParens(self) + else: + return visitor.visitChildren(self) + + + class FuncExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def func(self): + return self.getTypedRuleContext(RelayParser.FuncContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFuncExpr" ): + return visitor.visitFuncExpr(self) + else: + return visitor.visitChildren(self) + + + class ScalarExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + 
self.copyFrom(ctx) + + def scalar(self): + return self.getTypedRuleContext(RelayParser.ScalarContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarExpr" ): + return visitor.visitScalarExpr(self) + else: + return visitor.visitChildren(self) + + + class LetContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def var(self): + return self.getTypedRuleContext(RelayParser.VarContext,0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUT(self): + return self.getToken(RelayParser.MUT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitLet" ): + return visitor.visitLet(self) + else: + return visitor.visitChildren(self) + + + class TensorContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTensor" ): + return visitor.visitTensor(self) + else: + return visitor.visitChildren(self) + + + class IfElseContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + def body(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.BodyContext) + else: + return self.getTypedRuleContext(RelayParser.BodyContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIfElse" ): + 
return visitor.visitIfElse(self) + else: + return visitor.visitChildren(self) + + + class GraphContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitGraph" ): + return visitor.visitGraph(self) + else: + return visitor.visitChildren(self) + + + class BinOpContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ExprContext + super().__init__(parser) + self.op = None # Token + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ExprContext) + else: + return self.getTypedRuleContext(RelayParser.ExprContext,i) + + def MUL(self): + return self.getToken(RelayParser.MUL, 0) + def DIV(self): + return self.getToken(RelayParser.DIV, 0) + def ADD(self): + return self.getToken(RelayParser.ADD, 0) + def SUB(self): + return self.getToken(RelayParser.SUB, 0) + def LT(self): + return self.getToken(RelayParser.LT, 0) + def GT(self): + return self.getToken(RelayParser.GT, 0) + def LE(self): + return self.getToken(RelayParser.LE, 0) + def GE(self): + return self.getToken(RelayParser.GE, 0) + def EQ(self): + return self.getToken(RelayParser.EQ, 0) + def NE(self): + return self.getToken(RelayParser.NE, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitBinOp" ): + return visitor.visitBinOp(self) + else: + return visitor.visitChildren(self) + + + + def expr(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = RelayParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx 
+ _startState = 4 + self.enterRecursionRule(localctx, 4, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 125 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + if la_ == 1: + localctx = RelayParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 51 + self.match(RelayParser.T__0) + self.state = 52 + self.expr(0) + self.state = 53 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.NegContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 55 + self.match(RelayParser.SUB) + self.state = 56 + self.expr(17) + pass + + elif la_ == 3: + localctx = RelayParser.FuncExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 57 + self.func() + pass + + elif la_ == 4: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 58 + self.match(RelayParser.T__0) + self.state = 59 + self.match(RelayParser.T__1) + pass + + elif la_ == 5: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 60 + self.match(RelayParser.T__0) + self.state = 61 + self.expr(0) + self.state = 62 + self.match(RelayParser.T__2) + self.state = 63 + self.match(RelayParser.T__1) + pass + + elif la_ == 6: + localctx = RelayParser.TupleContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 65 + self.match(RelayParser.T__0) + self.state = 66 + self.expr(0) + self.state = 69 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 67 + self.match(RelayParser.T__2) + self.state = 68 + self.expr(0) + self.state = 71 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 73 + self.match(RelayParser.T__1) + pass + + elif la_ == 7: + localctx = 
RelayParser.TensorContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 75 + self.match(RelayParser.T__3) + self.state = 84 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 76 + self.expr(0) + self.state = 81 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 77 + self.match(RelayParser.T__2) + self.state = 78 + self.expr(0) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 86 + self.match(RelayParser.T__4) + pass + + elif la_ == 8: + localctx = RelayParser.IfElseContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 87 + self.match(RelayParser.T__5) + self.state = 88 + self.match(RelayParser.T__0) + self.state = 89 + self.expr(0) + self.state = 90 + self.match(RelayParser.T__1) + self.state = 91 + self.body() + self.state = 92 + self.match(RelayParser.T__6) + self.state = 93 + self.body() + pass + + elif la_ == 9: + localctx = RelayParser.LetContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 95 + self.match(RelayParser.T__7) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 96 + self.match(RelayParser.MUT) + + + self.state = 99 + self.var() + self.state = 100 + self.match(RelayParser.T__8) + self.state = 101 + self.expr(0) + self.state = 102 + self.match(RelayParser.T__9) + self.state = 103 + self.expr(6) + pass + + elif la_ == 10: + localctx = RelayParser.LetContext(self, localctx) + 
self._ctx = localctx + _prevctx = localctx + self.state = 105 + self.match(RelayParser.T__7) + self.state = 107 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.MUT: + self.state = 106 + self.match(RelayParser.MUT) + + + self.state = 109 + self.var() + self.state = 110 + self.match(RelayParser.T__8) + self.state = 111 + self.match(RelayParser.T__10) + self.state = 112 + self.expr(0) + self.state = 113 + self.match(RelayParser.T__11) + self.state = 114 + self.match(RelayParser.T__9) + self.state = 115 + self.expr(5) + pass + + elif la_ == 11: + localctx = RelayParser.GraphContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 117 + self.ident() + self.state = 118 + self.match(RelayParser.T__8) + self.state = 119 + self.expr(0) + self.state = 120 + self.match(RelayParser.T__9) + self.state = 121 + self.expr(3) + pass + + elif la_ == 12: + localctx = RelayParser.IdentExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 123 + self.ident() + pass + + elif la_ == 13: + localctx = RelayParser.ScalarExprContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 124 + self.scalar() + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 157 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 155 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,10,self._ctx) + if la_ == 1: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 127 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, 
"self.precpred(self._ctx, 16)") + self.state = 128 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.MUL or _la==RelayParser.DIV): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 129 + self.expr(17) + pass + + elif la_ == 2: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 130 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 131 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.ADD or _la==RelayParser.SUB): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 132 + self.expr(16) + pass + + elif la_ == 3: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 133 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 134 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not((((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.LT) | (1 << RelayParser.GT) | (1 << RelayParser.LE) | (1 << RelayParser.GE))) != 0)): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.expr(15) + pass + + elif la_ == 4: + localctx = RelayParser.BinOpContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 136 + if not 
self.precpred(self._ctx, 13): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 13)") + self.state = 137 + localctx.op = self._input.LT(1) + _la = self._input.LA(1) + if not(_la==RelayParser.EQ or _la==RelayParser.NE): + localctx.op = self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 138 + self.expr(14) + pass + + elif la_ == 5: + localctx = RelayParser.LetContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 139 + if not self.precpred(self._ctx, 4): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 140 + self.match(RelayParser.T__9) + self.state = 141 + self.expr(5) + pass + + elif la_ == 6: + localctx = RelayParser.CallContext(self, RelayParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 142 + if not self.precpred(self._ctx, 18): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 18)") + self.state = 143 + self.match(RelayParser.T__0) + self.state = 152 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__3) | (1 << RelayParser.T__5) | (1 << RelayParser.T__7) | (1 << RelayParser.T__12) | (1 << RelayParser.SUB) | (1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.BOOL_LIT) | (1 << RelayParser.FLOAT) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 144 + self.expr(0) + self.state = 149 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 145 + 
self.match(RelayParser.T__2) + self.state = 146 + self.expr(0) + self.state = 151 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 154 + self.match(RelayParser.T__1) + pass + + + self.state = 159 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class FuncContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_func + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFunc" ): + return visitor.visitFunc(self) + else: + return visitor.visitChildren(self) + + + + + def func(self): + + localctx = RelayParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_func) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 160 + self.match(RelayParser.T__12) + self.state = 162 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 161 + self.typeParamSeq() + + + self.state = 164 + self.match(RelayParser.T__0) + self.state = 165 + self.argList() + self.state = 166 + self.match(RelayParser.T__1) + self.state = 169 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 167 + 
self.match(RelayParser.T__13) + self.state = 168 + self.type_() + + + self.state = 171 + self.body() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DefnContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def argList(self): + return self.getTypedRuleContext(RelayParser.ArgListContext,0) + + + def body(self): + return self.getTypedRuleContext(RelayParser.BodyContext,0) + + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_defn + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitDefn" ): + return visitor.visitDefn(self) + else: + return visitor.visitChildren(self) + + + + + def defn(self): + + localctx = RelayParser.DefnContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_defn) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 173 + self.match(RelayParser.T__14) + self.state = 174 + self.ident() + self.state = 176 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 175 + self.typeParamSeq() + + + self.state = 178 + self.match(RelayParser.T__0) + self.state = 179 + self.argList() + self.state = 180 + self.match(RelayParser.T__1) + self.state = 183 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__13: + self.state = 181 + self.match(RelayParser.T__13) + self.state = 182 + self.type_() + + + self.state = 185 + self.body() + except RecognitionException as re: + 
localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varList(self): + return self.getTypedRuleContext(RelayParser.VarListContext,0) + + + def attrList(self): + return self.getTypedRuleContext(RelayParser.AttrListContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_argList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitArgList" ): + return visitor.visitArgList(self) + else: + return visitor.visitChildren(self) + + + + + def argList(self): + + localctx = RelayParser.ArgListContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_argList) + try: + self.state = 193 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,16,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 187 + self.varList() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 188 + self.attrList() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 189 + self.varList() + self.state = 190 + self.match(RelayParser.T__2) + self.state = 191 + self.attrList() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def var(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.VarContext) + else: + return self.getTypedRuleContext(RelayParser.VarContext,i) + + + def getRuleIndex(self): + 
return RelayParser.RULE_varList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitVarList" ): + return visitor.visitVarList(self) + else: + return visitor.visitChildren(self) + + + + + def varList(self): + + localctx = RelayParser.VarListContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_varList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.GLOBAL_VAR) | (1 << RelayParser.LOCAL_VAR) | (1 << RelayParser.GRAPH_VAR) | (1 << RelayParser.CNAME))) != 0): + self.state = 195 + self.var() + self.state = 200 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 196 + self.match(RelayParser.T__2) + self.state = 197 + self.var() + self.state = 202 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,17,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self): + return self.getTypedRuleContext(RelayParser.IdentContext,0) + + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def getRuleIndex(self): + return RelayParser.RULE_var + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitVar" ): + return visitor.visitVar(self) + else: + return visitor.visitChildren(self) + + + + + def var(self): + + localctx = RelayParser.VarContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_var) + self._la 
= 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 205 + self.ident() + self.state = 208 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__15: + self.state = 206 + self.match(RelayParser.T__15) + self.state = 207 + self.type_() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrListContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def attr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.AttrContext) + else: + return self.getTypedRuleContext(RelayParser.AttrContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_attrList + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAttrList" ): + return visitor.visitAttrList(self) + else: + return visitor.visitChildren(self) + + + + + def attrList(self): + + localctx = RelayParser.AttrListContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_attrList) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.CNAME: + self.state = 210 + self.attr() + self.state = 215 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 211 + self.match(RelayParser.T__2) + self.state = 212 + self.attr() + self.state = 217 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AttrContext(ParserRuleContext): + + def __init__(self, parser, 
parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_attr + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAttr" ): + return visitor.visitAttr(self) + else: + return visitor.visitChildren(self) + + + + + def attr(self): + + localctx = RelayParser.AttrContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_attr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 220 + self.match(RelayParser.CNAME) + self.state = 221 + self.match(RelayParser.T__8) + self.state = 222 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeParamSeqContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ident(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.IdentContext) + else: + return self.getTypedRuleContext(RelayParser.IdentContext,i) + + + def getRuleIndex(self): + return RelayParser.RULE_typeParamSeq + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeParamSeq" ): + return visitor.visitTypeParamSeq(self) + else: + return visitor.visitChildren(self) + + + + + def typeParamSeq(self): + + localctx = RelayParser.TypeParamSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_typeParamSeq) + self._la = 0 # Token type + try: + self.state = 237 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,23,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + 
self.state = 224 + self.match(RelayParser.T__3) + self.state = 225 + self.match(RelayParser.T__4) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 226 + self.match(RelayParser.T__3) + self.state = 227 + self.ident() + self.state = 232 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 228 + self.match(RelayParser.T__2) + self.state = 229 + self.ident() + self.state = 234 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 235 + self.match(RelayParser.T__4) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Type_Context(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_type_ + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class IntTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIntType" ): + return visitor.visitIntType(self) + else: + return visitor.visitChildren(self) + + + class TupleTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def type_(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTupleType" ): + return 
visitor.visitTupleType(self) + else: + return visitor.visitChildren(self) + + + class TypeIdentTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def typeIdent(self): + return self.getTypedRuleContext(RelayParser.TypeIdentContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeIdentType" ): + return visitor.visitTypeIdentType(self) + else: + return visitor.visitChildren(self) + + + class IncompleteTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIncompleteType" ): + return visitor.visitIncompleteType(self) + else: + return visitor.visitChildren(self) + + + class TensorTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def shapeSeq(self): + return self.getTypedRuleContext(RelayParser.ShapeSeqContext,0) + + def type_(self): + return self.getTypedRuleContext(RelayParser.Type_Context,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTensorType" ): + return visitor.visitTensorType(self) + else: + return visitor.visitChildren(self) + + + class FuncTypeContext(Type_Context): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.Type_Context + super().__init__(parser) + self.copyFrom(ctx) + + def type_(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.Type_Context) + else: + return self.getTypedRuleContext(RelayParser.Type_Context,i) + + def typeParamSeq(self): + return self.getTypedRuleContext(RelayParser.TypeParamSeqContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFuncType" 
): + return visitor.visitFuncType(self) + else: + return visitor.visitChildren(self) + + + + def type_(self): + + localctx = RelayParser.Type_Context(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_type_) + self._la = 0 # Token type + try: + self.state = 284 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 239 + self.match(RelayParser.T__0) + self.state = 240 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 241 + self.match(RelayParser.T__0) + self.state = 242 + self.type_() + self.state = 243 + self.match(RelayParser.T__2) + self.state = 244 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + localctx = RelayParser.TupleTypeContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.match(RelayParser.T__0) + self.state = 247 + self.type_() + self.state = 250 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 248 + self.match(RelayParser.T__2) + self.state = 249 + self.type_() + self.state = 252 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 254 + self.match(RelayParser.T__1) + pass + + elif la_ == 4: + localctx = RelayParser.TypeIdentTypeContext(self, localctx) + self.enterOuterAlt(localctx, 4) + self.state = 256 + self.typeIdent() + pass + + elif la_ == 5: + localctx = RelayParser.TensorTypeContext(self, localctx) + self.enterOuterAlt(localctx, 5) + self.state = 257 + self.match(RelayParser.T__16) + self.state = 258 + self.match(RelayParser.T__3) + self.state = 259 + self.shapeSeq() + self.state = 260 + self.match(RelayParser.T__2) + self.state = 261 + self.type_() + self.state = 262 + self.match(RelayParser.T__4) + pass + + elif la_ 
== 6: + localctx = RelayParser.FuncTypeContext(self, localctx) + self.enterOuterAlt(localctx, 6) + self.state = 264 + self.match(RelayParser.T__12) + self.state = 266 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==RelayParser.T__3: + self.state = 265 + self.typeParamSeq() + + + self.state = 268 + self.match(RelayParser.T__0) + self.state = 277 + self._errHandler.sync(self) + _la = self._input.LA(1) + if (((_la) & ~0x3f) == 0 and ((1 << _la) & ((1 << RelayParser.T__0) | (1 << RelayParser.T__12) | (1 << RelayParser.T__16) | (1 << RelayParser.T__17) | (1 << RelayParser.NAT) | (1 << RelayParser.CNAME))) != 0): + self.state = 269 + self.type_() + self.state = 274 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==RelayParser.T__2: + self.state = 270 + self.match(RelayParser.T__2) + self.state = 271 + self.type_() + self.state = 276 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 279 + self.match(RelayParser.T__1) + self.state = 280 + self.match(RelayParser.T__13) + self.state = 281 + self.type_() + pass + + elif la_ == 7: + localctx = RelayParser.IncompleteTypeContext(self, localctx) + self.enterOuterAlt(localctx, 7) + self.state = 282 + self.match(RelayParser.T__17) + pass + + elif la_ == 8: + localctx = RelayParser.IntTypeContext(self, localctx) + self.enterOuterAlt(localctx, 8) + self.state = 283 + self.match(RelayParser.NAT) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeSeqContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def shape(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(RelayParser.ShapeContext) + else: + return self.getTypedRuleContext(RelayParser.ShapeContext,i) + + + def 
getRuleIndex(self): + return RelayParser.RULE_shapeSeq + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShapeSeq" ): + return visitor.visitShapeSeq(self) + else: + return visitor.visitChildren(self) + + + + + def shapeSeq(self): + + localctx = RelayParser.ShapeSeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_shapeSeq) + self._la = 0 # Token type + try: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 286 + self.match(RelayParser.T__0) + self.state = 287 + self.match(RelayParser.T__1) + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 288 + self.match(RelayParser.T__0) + self.state = 289 + self.shape() + self.state = 290 + self.match(RelayParser.T__2) + self.state = 291 + self.match(RelayParser.T__1) + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 293 + self.match(RelayParser.T__0) + self.state = 294 + self.shape() + self.state = 297 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 295 + self.match(RelayParser.T__2) + self.state = 296 + self.shape() + self.state = 299 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==RelayParser.T__2): + break + + self.state = 301 + self.match(RelayParser.T__1) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ShapeContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_shape + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class ParensShapeContext(ShapeContext): + + def 
__init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext + super().__init__(parser) + self.copyFrom(ctx) + + def shape(self): + return self.getTypedRuleContext(RelayParser.ShapeContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitParensShape" ): + return visitor.visitParensShape(self) + else: + return visitor.visitChildren(self) + + + class IntShapeContext(ShapeContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ShapeContext + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIntShape" ): + return visitor.visitIntShape(self) + else: + return visitor.visitChildren(self) + + + + def shape(self): + + localctx = RelayParser.ShapeContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_shape) + try: + self.state = 310 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.T__0]: + localctx = RelayParser.ParensShapeContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 305 + self.match(RelayParser.T__0) + self.state = 306 + self.shape() + self.state = 307 + self.match(RelayParser.T__1) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.IntShapeContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 309 + self.match(RelayParser.NAT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TypeIdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CNAME(self): + return self.getToken(RelayParser.CNAME, 0) + + def 
getRuleIndex(self): + return RelayParser.RULE_typeIdent + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTypeIdent" ): + return visitor.visitTypeIdent(self) + else: + return visitor.visitChildren(self) + + + + + def typeIdent(self): + + localctx = RelayParser.TypeIdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_typeIdent) + try: + self.enterOuterAlt(localctx, 1) + self.state = 312 + self.match(RelayParser.CNAME) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BodyContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(RelayParser.ExprContext,0) + + + def getRuleIndex(self): + return RelayParser.RULE_body + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitBody" ): + return visitor.visitBody(self) + else: + return visitor.visitChildren(self) + + + + + def body(self): + + localctx = RelayParser.BodyContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_body) + try: + self.enterOuterAlt(localctx, 1) + self.state = 314 + self.match(RelayParser.T__10) + self.state = 315 + self.expr(0) + self.state = 316 + self.match(RelayParser.T__11) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ScalarContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return RelayParser.RULE_scalar + + + def copyFrom(self, ctx:ParserRuleContext): + 
super().copyFrom(ctx) + + + + class ScalarFloatContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(RelayParser.FLOAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarFloat" ): + return visitor.visitScalarFloat(self) + else: + return visitor.visitChildren(self) + + + class ScalarBoolContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def BOOL_LIT(self): + return self.getToken(RelayParser.BOOL_LIT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarBool" ): + return visitor.visitScalarBool(self) + else: + return visitor.visitChildren(self) + + + class ScalarIntContext(ScalarContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a RelayParser.ScalarContext + super().__init__(parser) + self.copyFrom(ctx) + + def NAT(self): + return self.getToken(RelayParser.NAT, 0) + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitScalarInt" ): + return visitor.visitScalarInt(self) + else: + return visitor.visitChildren(self) + + + + def scalar(self): + + localctx = RelayParser.ScalarContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_scalar) + try: + self.state = 321 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.FLOAT]: + localctx = RelayParser.ScalarFloatContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.match(RelayParser.FLOAT) + pass + elif token in [RelayParser.NAT]: + localctx = RelayParser.ScalarIntContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 319 + self.match(RelayParser.NAT) + pass + elif token in [RelayParser.BOOL_LIT]: + localctx = 
RelayParser.ScalarBoolContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 320 + self.match(RelayParser.BOOL_LIT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IdentContext(ParserRuleContext): + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def opIdent(self): + return self.getTypedRuleContext(RelayParser.OpIdentContext,0) + + + def GLOBAL_VAR(self): + return self.getToken(RelayParser.GLOBAL_VAR, 0) + + def LOCAL_VAR(self): + return self.getToken(RelayParser.LOCAL_VAR, 0) + + def GRAPH_VAR(self): + return self.getToken(RelayParser.GRAPH_VAR, 0) + + def getRuleIndex(self): + return RelayParser.RULE_ident + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdent" ): + return visitor.visitIdent(self) + else: + return visitor.visitChildren(self) + + + + + def ident(self): + + localctx = RelayParser.IdentContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_ident) + try: + self.state = 327 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [RelayParser.CNAME]: + self.enterOuterAlt(localctx, 1) + self.state = 323 + self.opIdent() + pass + elif token in [RelayParser.GLOBAL_VAR]: + self.enterOuterAlt(localctx, 2) + self.state = 324 + self.match(RelayParser.GLOBAL_VAR) + pass + elif token in [RelayParser.LOCAL_VAR]: + self.enterOuterAlt(localctx, 3) + self.state = 325 + self.match(RelayParser.LOCAL_VAR) + pass + elif token in [RelayParser.GRAPH_VAR]: + self.enterOuterAlt(localctx, 4) + self.state = 326 + self.match(RelayParser.GRAPH_VAR) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + 
self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[2] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 13) + + + if predIndex == 4: + return self.precpred(self._ctx, 4) + + + if predIndex == 5: + return self.precpred(self._ctx, 18) + + + + + diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py new file mode 100644 index 000000000000..64308dca1a3a --- /dev/null +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -0,0 +1,198 @@ +# Generated from /home/sslyu/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +from antlr4 import * +if __name__ is not None and "." in __name__: + from .RelayParser import RelayParser +else: + from RelayParser import RelayParser + +# This class defines a complete generic visitor for a parse tree produced by RelayParser. + +class RelayVisitor(ParseTreeVisitor): + + # Visit a parse tree produced by RelayParser#opIdent. + def visitOpIdent(self, ctx:RelayParser.OpIdentContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#prog. + def visitProg(self, ctx:RelayParser.ProgContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#identExpr. 
+ def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#call. + def visitCall(self, ctx:RelayParser.CallContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#neg. + def visitNeg(self, ctx:RelayParser.NegContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tuple. + def visitTuple(self, ctx:RelayParser.TupleContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parens. + def visitParens(self, ctx:RelayParser.ParensContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcExpr. + def visitFuncExpr(self, ctx:RelayParser.FuncExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarExpr. + def visitScalarExpr(self, ctx:RelayParser.ScalarExprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#let. + def visitLet(self, ctx:RelayParser.LetContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensor. + def visitTensor(self, ctx:RelayParser.TensorContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ifElse. + def visitIfElse(self, ctx:RelayParser.IfElseContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#graph. + def visitGraph(self, ctx:RelayParser.GraphContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#binOp. + def visitBinOp(self, ctx:RelayParser.BinOpContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#func. + def visitFunc(self, ctx:RelayParser.FuncContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#defn. 
+ def visitDefn(self, ctx:RelayParser.DefnContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#argList. + def visitArgList(self, ctx:RelayParser.ArgListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#varList. + def visitVarList(self, ctx:RelayParser.VarListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#var. + def visitVar(self, ctx:RelayParser.VarContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attrList. + def visitAttrList(self, ctx:RelayParser.AttrListContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#attr. + def visitAttr(self, ctx:RelayParser.AttrContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeParamSeq. + def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tupleType. + def visitTupleType(self, ctx:RelayParser.TupleTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdentType. + def visitTypeIdentType(self, ctx:RelayParser.TypeIdentTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#tensorType. + def visitTensorType(self, ctx:RelayParser.TensorTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#funcType. + def visitFuncType(self, ctx:RelayParser.FuncTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#incompleteType. + def visitIncompleteType(self, ctx:RelayParser.IncompleteTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intType. 
+ def visitIntType(self, ctx:RelayParser.IntTypeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#shapeSeq. + def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#parensShape. + def visitParensShape(self, ctx:RelayParser.ParensShapeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#intShape. + def visitIntShape(self, ctx:RelayParser.IntShapeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#typeIdent. + def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#body. + def visitBody(self, ctx:RelayParser.BodyContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarFloat. + def visitScalarFloat(self, ctx:RelayParser.ScalarFloatContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarInt. + def visitScalarInt(self, ctx:RelayParser.ScalarIntContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#scalarBool. + def visitScalarBool(self, ctx:RelayParser.ScalarBoolContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by RelayParser#ident. + def visitIdent(self, ctx:RelayParser.IdentContext): + return self.visitChildren(ctx) + + + +del RelayParser \ No newline at end of file diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 5f23e14d5559..dd0f54c664ca 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -17,324 +17,16 @@ # pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck """ -This file contains: -1. The set of passes for Relay, which exposes an interface for configuring the - passes and scripting them in Python. - -2. 
The pass manager for Relay which exposes different granularity of interfaces - for users to implement and use passes more conveniently. +This file contains the set of passes for Relay, which exposes an interface for +configuring the passes and scripting them in Python. """ -import types - from . import _ir_pass from . import _make from .expr import Expr from .ty import Type -from .base import RelayNode, register_relay_node from .module import Module -@register_relay_node -class PassInfo(RelayNode): - """The class that contains the meta data required by a pass. It is the - container of information needed by running an optimization or analysis. - This class can be extended by adding new members when more meta data is - needed. - - Parameters - ---------- - name : str - The pass name. - - opt_level : int - The optimization level of this pass. - - required : List[str] - The list of passes that are required by a certain pass. - """ - - def __init__(self, name, opt_level, required=None): - self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level, - required) - - -@register_relay_node -class PassContext(RelayNode): - """The basis where a Relay optimization/analysis runs on. - Each pass context contains a number of auxiliary information that is used - to help an optimization pass. Such information includes the error reporter - to record the errors of during the optimization, etc. - """ - - def __init__(self): - self.__init_handle_by_constructor__(_ir_pass.PassContext) - - -@register_relay_node -class Pass(RelayNode): - """The base class of all passes. All methods here are just simple wrappers - that are implemented in the backend. They are defined for users to - conveniently interact with the base class. - """ - - def set_pass_context(self, pass_ctx): - """Setup the pass context for analysis and optimizations. This context - could be shared by different passes for sequential passes. 
- - Parameters - ---------- - pass_ctx : PassContext - The context that is used to help perform a certain pass or a series - of passes. - """ - if not isinstance(pass_ctx, PassContext): - raise TypeError("pass_ctx is expected to be the PassContext type") - _ir_pass.SetContext(self, pass_ctx) - - @property - def info(self): - """Get the pass meta.""" - return _ir_pass.Info(self) - - def __call__(self, mod): - """Execute the pass. Note that for sequential pass, the dependency among - different passes will be resolved in the backend. - - Parameters - ---------- - mod : tvm.relay.Module - The module that a certain optimization is performed on. - - Returns - ------- - mod : tvm.relay.Module - The updated module after applying this pass. - """ - return _ir_pass.RunPass(self, mod) - - -@register_relay_node -class ModulePass(Pass): - """A pass that works on tvm.relay.Module. Users don't need to interact with - this class directly. Instead, a module pass should be created through - `module_pass`, because the design of the `module_pass` API is flexible - enough to handle the creation of a module pass in different manners. In - addition, all members of a module pass can be accessed from the base class. - The same rule applies to FunctionPass and SequentialPass as well. - """ - - -@register_relay_node -class FunctionPass(Pass): - """A pass that works on each tvm.relay.Function in a module. A function - pass class should be created through `function_pass`. - """ - - -@register_relay_node -class SequentialPass(Pass): - """A pass that works on a sequence of pass objects. A sequential pass class - should be created through `sequential_pass`. - """ - - -def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. 
- - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the module pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. - - Examples - -------- - The following code creates a module level pass and adds an abs function to - the module. - - .. code-block:: python - - @relay.ir_pass.module_pass(opt_level=2) - def transform(mod, ctx): - tp = relay.TensorType((10,), "float32") - x = relay.var("x", tp) - gv = relay.GlobalVar("var") - func = relay.Function([x], relay.abs(x)) - new_mod = relay.Module({gv: func}) - new_mod.update(mod) - return new_mod - - module_pass = transform - assert isinstance(module_pass, ir_pass.ModulePass) - assert module_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = module_pass(m) - # Now a function abs should be added to the module m. 
- """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the module pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_module_pass(pass_func): - """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateModulePass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_module_pass(pass_func) - return create_module_pass - - -def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created function pass using the - given optimization function. - - Parameters - ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. - - opt_level : int - The optimization level of this module pass. - - name : Optional[str] - The name of the function pass. The name could be empty. In this case, the - name of the optimization function will be used as the pass name. - - required : Optional[List[str]] - The list of passes that the module pass is dependent on. - - Returns - ------- - create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. - - Examples - -------- - The following code creates a function level pass that performs constant - folding. - - .. 
code-block:: python - - @relay.ir_pass.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) - - function_pass = transform - assert isinstance(function_pass, ir_pass.FunctionPass) - assert function_pass.info.opt_level == 2 - - # Given a module m, the optimization could be invoked as the follwoing: - updated_mod = function_pass(m) - # Now constant folding should have been applied to every function in - # the provided module m. And the updated module will be returned. - """ - - if opt_level is None: - raise ValueError("Please provide opt_level for the funtion pass.") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of " + - "list/tuple.") - - def create_function_pass(pass_func): - """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - return _ir_pass.CreateFunctionPass(pass_func, opt_level, - name if name else pass_func.__name__, - required) - - if pass_func: - return create_function_pass(pass_func) - return create_function_pass - - -def sequential_pass(passes=None, opt_level=2, name="sequential_pass", - required=None, disabled=None): - """Create a sequential pass using a defined optimization function from - Python. Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to - apply when running a sequential pass. Pass dependency will be resolved in - the backend as well. - - Parameters - ---------- - passes : Optional[List[Pass]] - A sequence of passes candidate for optimization. 
- - opt_level : Optional[int] - The optimization level of this sequential pass. - - name : Optional[str] - The name of the sequential pass. - - required : Optional[List[str]] - The list of passes that the sequential pass is dependent on. - - disabled : Optional[List[str]] - A list of disabled passes. - - Returns - ------- - ret : Pass - A sequential pass built through pass_func. - """ - - passes = passes if passes else [] - if not isinstance(passes, (list, tuple)): - raise TypeError("passes must be a list of Pass objects.") - - disabled = disabled if disabled else [] - if not isinstance(disabled, (list, tuple)): - raise TypeError("disabled must be a list or tuple of pass names") - - required = required if required else [] - if not isinstance(required, (list, tuple)): - raise TypeError("Required is expected to be the type of list/tuple.") - - return _ir_pass.CreateSequentialPass(passes, opt_level, name, required, - disabled) - - def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. 
Each node is guaranteed to be visited @@ -437,7 +129,7 @@ def well_formed(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -483,7 +175,7 @@ def free_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -505,7 +197,7 @@ def bound_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -521,7 +213,7 @@ def all_vars(expr): Parameters ---------- - expr: tvm.relay.Expr + expr : tvm.relay.Expr The input expression Returns @@ -537,9 +229,10 @@ def free_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -556,9 +249,10 @@ def bound_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + + mod : Optional[tvm.relay.Module] The global module Returns @@ -575,9 +269,9 @@ def all_type_vars(expr, mod=None): Parameters ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] + expr : Union[tvm.relay.Expr,tvm.relay.Type] The input expression/type - mod: tvm.relay.Module, optional + mod : Optional[tvm.relay.Module] The global module Returns @@ -594,12 +288,12 @@ def simplify_inference(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with some simplification """ @@ -612,32 +306,34 @@ def canonicalize_ops(expr): Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression 
without bias_add """ return _ir_pass.canonicalize_ops(expr) -def dead_code_elimination(expr): +def dead_code_elimination(expr, inline_once=False): """ Remove expressions which does not effect the program result (dead code). Parameters ---------- - e: tvm.relay.Expr + expr : tvm.relay.Expr The input Expression + inline_once : Optional[Bool] + Whether to inline binding that occur only once. Returns ------- - result: tvm.relay.Expr + result : tvm.relay.Expr An expression which is semantically equal to the input expression, but with dead code removed. """ - return _ir_pass.dead_code_elimination(expr) + return _ir_pass.dead_code_elimination(expr, inline_once) def alpha_equal(lhs, rhs): @@ -645,15 +341,15 @@ def alpha_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is alpha equal to rhs. """ return bool(_make._alpha_equal(lhs, rhs)) @@ -667,15 +363,15 @@ def graph_equal(lhs, rhs): Parameters ---------- - lhs: tvm.relay.Expr + lhs : tvm.relay.Expr One of the input Expression. - rhs: tvm.relay.Expr + rhs : tvm.relay.Expr One of the input Expression. Returns ------- - result: bool + result : bool True iff lhs is data-flow equivalent to rhs. """ return bool(_make._graph_equal(lhs, rhs)) @@ -686,12 +382,12 @@ def structural_hash(value): Parameters ---------- - expr: tvm.relay.Expr or tvm.relay.Type + expr : Union[tvm.relay.Expr, tvm.relay.Type] The expression to hash. Returns ------- - result: int + result : int The hash value """ if isinstance(value, Expr): @@ -852,12 +548,12 @@ def to_a_normal_form(expr, mod=None): expr : tvm.relay.Expr The input expression. - mod: Optional[tvm.relay.Module] + mod : Optional[tvm.relay.Module] The global module. Returns ------- - expr: tvm.relay.Expr + result : tvm.relay.Expr The output expression. 
""" return _ir_pass.to_a_normal_form(expr, mod) @@ -871,7 +567,7 @@ def to_graph_normal_form(expr): The input expression Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression """ return _ir_pass.to_graph_normal_form(expr) @@ -920,7 +616,7 @@ def get_total_mac_number(expr): Returns ------- - ret : int64 + result : int64 The number of MACs (multiply-accumulate) of a model """ return _ir_pass.GetTotalMacNumber(expr) @@ -935,17 +631,17 @@ def eliminate_common_subexpr(expr, fskip=None): expr : tvm.relay.Expr The input expression. - fskip: function + fskip : function The callback function that decides whether an expression should be skipped. Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ return _ir_pass.eliminate_common_subexpr(expr, fskip) -def partial_evaluate(expr): +def partial_evaluate(expr, mod=None): """ Evaluate the static fragment of the code. @@ -954,9 +650,30 @@ def partial_evaluate(expr): expr : tvm.relay.Expr The input expression. + mod : Optional[tvm.relay.Module] + The global module + Returns ------- - expr : tvm.relay.Expr + result : tvm.relay.Expr The output expression. """ - return _ir_pass.partial_evaluate(expr) + return _ir_pass.partial_evaluate(expr, mod) + +def unmatched_cases(match, mod=None): + """ + Finds cases that the match expression does not catch, if any. + + Parameters + ---------- + match : tvm.relay.Match + The match expression + mod : Optional[tvm.relay.Module] + The module (defaults to an empty module) + + Returns + ------- + missing_patterns : [tvm.relay.Pattern] + Patterns that the match expression does not catch. 
+ """ + return _ir_pass.unmatched_cases(match, mod) diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 57e716534ee5..09746be13e30 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target): """Compute definition of argsort""" axis = get_const_int(attrs.axis) is_ascend = bool(get_const_int(attrs.is_ascend)) - dtype = str(attrs.dtype) - return [ - topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \ - dtype=dtype, flag=False) - ] + dtype = attrs.dtype + return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)] register_pattern("argsort", OpPattern.OPAQUE) + + +@register_schedule("topk") +def schedule_topk(_, outs, target): + """Schedule definition of argsort""" + with target: + return topi.generic.schedule_topk(outs) + + +@register_compute("topk") +def compute_topk(attrs, inputs, _, target): + """Compute definition of argsort""" + k = get_const_int(attrs.k) + axis = get_const_int(attrs.axis) + ret_type = attrs.ret_type + is_ascend = bool(get_const_int(attrs.is_ascend)) + dtype = attrs.dtype + out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype) + out = out if isinstance(out, list) else [out] + return out + + +register_pattern("topk", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index b97e3a8ce993..b7c9a79a8ad9 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("argmax", _schedule_reduce) _reg.register_schedule("argmin", _schedule_reduce) _reg.register_schedule("sum", _schedule_reduce) +_reg.register_schedule("all", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce) _reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) diff --git a/python/tvm/relay/op/_transform.py 
b/python/tvm/relay/op/_transform.py index 2eec6d03e7cd..95fb2ad18a25 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -23,6 +23,7 @@ schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective +schedule_concatenate = _reg.schedule_concatenate _reg.register_schedule("collapse_sum_like", _schedule_reduce) @@ -46,7 +47,7 @@ _reg.register_schedule("transpose", schedule_injective) _reg.register_schedule("where", schedule_broadcast) _reg.register_schedule("stack", schedule_injective) -_reg.register_schedule("concatenate", schedule_injective) +_reg.register_schedule("concatenate", schedule_concatenate) _reg.register_schedule("_contrib_reverse_reshape", schedule_injective) _reg.register_schedule("gather_nd", schedule_injective) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 6451eb41aeb9..6f875919df4c 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,8 +17,9 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make +from ..expr import TupleWrapper -def argsort(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -37,7 +38,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): Whether to sort in ascending or descending order. dtype : string, optional - DType of the output indices. + The data type of the output indices. Returns ------- @@ -45,3 +46,42 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): Tensor with same shape as data. """ return _make.argsort(data, axis, is_ascend, dtype) + + +def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): + """Get the top k elements in an input tensor along the given axis. 
+ + ret_type specifies the return type, can be one of ("both", "values", "indices"). + + Parameters + ---------- + data : relay.Expr + The input data tensor. + + k : int, optional + Number of top elements to select. Return all elements if k < 1. + + axis : int, optional + Axis long which to sort the input tensor. + + ret_type: str, optional + The return type [both, values, indices]. + "both": return both top k data and indices. + "values": return top k data only. + "indices": return top k indices only. + + is_ascend : boolean, optional + Whether to sort in ascending or descending order. + + dtype : string, optional + The data type of the indices output. + + Returns + ------- + out : relay.Expr or List[relay.Expr] + The computed result. + """ + out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) + if ret_type == "both": + return TupleWrapper(out, 2) + return out diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index b772c43e11cd..7bce9dd3c5b9 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -67,7 +67,7 @@ def conv2d(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -129,7 +129,7 @@ def conv2d_transpose(data, The weight expressions. strides : Tuple[int], optional - The strides of convoltution. + The strides of convolution. padding : Tuple[int], optional The padding of convolution on both sides of inputs. 
@@ -401,7 +401,7 @@ def upsampling(data, with data of shape (n, c, h, w) out will have a shape (n, c, h*scale, w*scale) - method indicates the algorithm to be used while calculating ghe out value + method indicates the algorithm to be used while calculating the out value and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") Parameters @@ -842,7 +842,7 @@ def contrib_conv2d_winograd_without_weight_transform(data, The Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -908,7 +908,7 @@ def contrib_conv2d_winograd_nnpack_without_weight_transform(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -975,7 +975,7 @@ def contrib_conv2d_nchwc(data, The kernel expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -1040,7 +1040,7 @@ def contrib_depthwise_conv2d_nchwc(data, The kernel expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. @@ -1156,7 +1156,7 @@ def deformable_conv2d(data, The weight expressions. strides : tuple of int, optional - The strides of convoltution. + The strides of convolution. padding : tuple of int, optional The padding of convolution on both sides of inputs before convolution. 
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 6ba207934d1b..906bf255d46e 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -219,6 +219,13 @@ def schedule_injective(attrs, outputs, target): with target: return topi.generic.schedule_injective(outputs) + +def schedule_concatenate(attrs, outputs, target): + """Generic schedule for concatinate.""" + with target: + return topi.generic.schedule_concatenate(outputs) + + __DEBUG_COUNTER__ = 0 def debug(expr, debug_func=None): diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 9d58a92041f3..0f2594600b0a 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False): return _make.sum(data, axis, keepdims, exclude) +def all(data, axis=None, keepdims=False, exclude=False): + """Computes the logical AND of boolean array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input boolean tensor + + axis : None or int or tuple of int + Axis or axes along which a sum is performed. The default, axis=None, + will sum all of the elements of the input array. If axis is + negative it counts from the last to the first axis. 
+ + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast + correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + data = relay.Constant(tvm.nd.array([[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]])) + + relay.all(data, axis=1) + # [[False, True, False], + # [False, False, False]] + + relay.all(data, axis=0) + # [[ True, False, False], + # [ True, True, False], + # [False, True, False]] + + """ + axis = [axis] if axis and isinstance(axis, int) else axis + return _make.all(data, axis, keepdims, exclude) + + def max(data, axis=None, keepdims=False, exclude=False): """ Computes the max of array elements over given axes. @@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. Returns ------- @@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False): exclude : bool If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead. + NOT in axis instead. 
Returns ------- diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 9c76b7e569dc..dce2258946cd 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -218,9 +218,10 @@ def take(data, indices, axis=None, mode="clip"): the flattened input array is used. mode : str, optional - Specifies how out-of-bound indices will behave. - clip - clip to the range (default) - wrap - wrap around the indices + Specifies how out-of-bound indices will behave [clip, wrap, fast]. + clip: clip to the range (default). + wrap: wrap around the indices. + fast: no clip or wrap around (user must make sure indices are in-bound). Returns ------- diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8c8c4cd9aaa3..7de118071aa4 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -82,7 +82,10 @@ def schedule_get_valid_counts(_, outs, target): def compute_get_valid_counts(attrs, inputs, _, target): """Compute definition of get_valid_counts""" score_threshold = get_const_float(attrs.score_threshold) - return topi.vision.get_valid_counts(inputs[0], score_threshold) + id_index = get_const_int(attrs.id_index) + score_index = get_const_int(attrs.score_index) + return topi.vision.get_valid_counts(inputs[0], score_threshold, + id_index, score_index) reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index ab34eb6e6cfb..d19dde306aca 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -20,7 +20,9 @@ from ...expr import TupleWrapper def get_valid_counts(data, - score_threshold): + score_threshold, + id_index=0, + score_index=1): """Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. 
@@ -32,6 +34,12 @@ def get_valid_counts(data, score_threshold : optional, float Lower limit of score for valid bounding boxes. + id_index : optional, int + index of the class categories, -1 to disable. + + score_index: optional, int + Index of the scores/confidence of boxes. + Returns ------- valid_count : relay.Expr @@ -40,7 +48,8 @@ def get_valid_counts(data, out_tensor : relay.Expr Rearranged data tensor. """ - return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2) + return TupleWrapper(_make.get_valid_counts(data, score_threshold, + id_index, score_index), 2) def non_max_suppression(data, diff --git a/python/tvm/relay/parser.py b/python/tvm/relay/parser.py index 2d27b7b53f89..9218cae3de66 100644 --- a/python/tvm/relay/parser.py +++ b/python/tvm/relay/parser.py @@ -18,19 +18,6 @@ from __future__ import absolute_import from .. import register_func -def enabled(): - """Checks whether the parser is enabled, this allows users to - optionally support building the parser. - - We use this check before importing the parser. - """ - try: - # pylint: disable=unused-variable - from tvm.relay import _parser - return True - # pylint: disable=broad-except - except Exception: - return False @register_func("relay.fromtext") def fromtext(data, source_name=None): diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index ff823c3413fa..af0497e3801a 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -15,11 +15,16 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name -"""Adds certain standard global functions and ADT definitions to the module.""" +"""A prelude containing useful global functions and ADT definitions.""" +import os from .ty import GlobalTypeVar, TypeVar, FuncType, TupleType, scalar_type -from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem +from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const +from .op.tensor import add, subtract, equal from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard +from .parser import fromtext + +__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) class Prelude: """Contains standard definitions.""" @@ -61,39 +66,44 @@ def define_list_tl(self): cons_case = Clause(PatternConstructor(self.cons, [PatternVar(y), PatternVar(z)]), z) self.mod[self.tl] = Function([x], Match(x, [cons_case]), self.l(a), [a]) + def define_list_nth(self): """Defines a function to get the nth element of a list. - nth(l) : list[a] -> a + nth(l) : list[a] -> Tensor[(), int32] -> a """ self.nth = GlobalVar("nth") a = TypeVar("a") x = Var("x", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) + + body = If(equal(n, const(0)), + self.hd(x), + self.nth(self.tl(x), subtract(n, const(1)))) + + self.mod[self.nth] = Function([x, n], body, a, [a]) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.hd(x)) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), self.nth(self.tl(x), y)) - self.mod[self.nth] = Function([x, n], Match(n, [z_case, s_case]), a, [a]) def define_list_update(self): """Defines a function to update the nth element of a list and return the updated list. 
- update(l, i, v) : list[a] -> nat -> a -> list[a] + update(l, i, v) : list[a] -> Tensor[(), int32] -> a -> list[a] """ self.update = GlobalVar("update") a = TypeVar("a") l = Var("l", self.l(a)) - n = Var("n", self.nat()) + n = Var("n", scalar_type('int32')) v = Var("v", a) - y = Var("y") + body = If(equal(n, const(0)), + self.cons(v, self.tl(l)), + self.cons(self.hd(l), + self.update(self.tl(l), + subtract(n, const(1)), + v))) - z_case = Clause(PatternConstructor(self.z), self.cons(v, self.tl(l))) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.cons(self.hd(l), self.update(self.tl(l), y, v))) + self.mod[self.update] = Function([l, n, v], body, self.l(a), [a]) - self.mod[self.update] = Function([l, n, v], Match(n, [z_case, s_case]), self.l(a), [a]) def define_list_map(self): """Defines a function for mapping a function over a list's @@ -114,6 +124,7 @@ def define_list_map(self): self.cons(f(y), self.map(f, z))) self.mod[self.map] = Function([f, x], Match(x, [nil_case, cons_case]), self.l(b), [a, b]) + def define_list_foldl(self): """Defines a left-way fold over a list. @@ -136,6 +147,7 @@ def define_list_foldl(self): self.mod[self.foldl] = Function([f, av, bv], Match(bv, [nil_case, cons_case]), a, [a, b]) + def define_list_foldr(self): """Defines a right-way fold over a list. @@ -158,6 +170,7 @@ def define_list_foldr(self): self.mod[self.foldr] = Function([f, bv, av], Match(av, [nil_case, cons_case]), b, [a, b]) + def define_list_foldr1(self): """Defines a right-way fold over a nonempty list. @@ -196,6 +209,7 @@ def define_list_concat(self): self.foldr(updater, l2, l1), self.l(a), [a]) + def define_list_filter(self): """Defines a function that filters a list. 
@@ -214,6 +228,7 @@ def define_list_filter(self): If(f(h), self.cons(h, self.filter(f, t)), self.filter(f, t))) self.mod[self.filter] = Function([f, l], Match(l, [nil_case, cons_case]), self.l(a), [a]) + def define_list_zip(self): """Defines a function that combines two lists into a list of tuples of their elements. @@ -238,6 +253,7 @@ def define_list_zip(self): self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]), self.l(TupleType([a, b])), [a, b]) + def define_list_rev(self): """Defines a function that reverses a list. @@ -253,6 +269,7 @@ def define_list_rev(self): self.foldl(updater, self.nil(), l), self.l(a), [a]) + def define_list_map_accumr(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -282,6 +299,7 @@ def define_list_map_accumr(self): TupleType([a, self.l(c)]), [a, b, c]) + def define_list_map_accuml(self): """Defines an accumulative map, which is a fold that simulataneously updates an accumulator value and a list of results. @@ -321,6 +339,7 @@ def define_optional_adt(self): self.none = Constructor("none", [], self.optional) self.mod[self.optional] = TypeData(self.optional, [a], [self.some, self.none]) + def define_list_unfoldr(self): """Defines a function that builds up a list starting from a seed value. @@ -343,6 +362,7 @@ def define_list_unfoldr(self): self.mod[self.unfoldr] = Function([f, s], Match(f(s), [none_case, some_case]), self.l(b), [a, b]) + def define_list_unfoldl(self): """Defines a function that builds up a list starting from a seed value. @@ -362,52 +382,29 @@ def define_list_unfoldl(self): self.rev(self.unfoldr(f, s)), self.l(b), [a, b]) - def define_nat_adt(self): - """Defines a Peano (unary) natural number ADT. - Zero is represented by z(). 
s(n) adds 1 to a nat n.""" - self.nat = GlobalTypeVar("nat") - self.z = Constructor("z", [], self.nat) - self.s = Constructor("s", [self.nat()], self.nat) - self.mod[self.nat] = TypeData(self.nat, [], [self.z, self.s]) - - def define_nat_double(self): - """Defines a function that doubles a nat.""" - self.double = GlobalVar("double") - x = Var("x", self.nat()) - y = Var("y") - z_case = Clause(PatternConstructor(self.z), self.z()) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.s(self.s(self.double(y)))) - self.mod[self.double] = Function([x], Match(x, [z_case, s_case])) - - def define_nat_add(self): - """Defines a function that adds two nats.""" - self.add = GlobalVar("add") - x = Var("x", self.nat()) - y = Var("y", self.nat()) - a = Var("a") - z_case = Clause(PatternConstructor(self.z), y) - s_case = Clause(PatternConstructor(self.s, [PatternVar(a)]), - self.s(self.add(a, y))) - self.mod[self.add] = Function([x, y], Match(x, [z_case, s_case])) def define_list_sum(self): - """Defines a function that computes the sum of a list of nats.""" + """Defines a function that computes the sum of a list of integer scalars.""" self.sum = GlobalVar("sum") - a = Var("a", self.l(self.nat())) - self.mod[self.sum] = Function([a], self.foldl(self.add, self.z(), a)) + a = Var("a", self.l(scalar_type('int32'))) + x = Var('x') + y = Var('y') + addf = Function([x, y], add(x, y)) + self.mod[self.sum] = Function([a], self.foldl(addf, const(0), a)) + def define_list_length(self): - """Defines a function that returns the length of a list as a nat""" + """Defines a function that returns the length of a list""" self.length = GlobalVar("length") a = TypeVar("a") x = Var("x", self.l(a)) y = Var("y") - nil_case = Clause(PatternConstructor(self.nil), self.z()) + nil_case = Clause(PatternConstructor(self.nil), const(0)) cons_case = Clause(PatternConstructor(self.cons, [PatternWildcard(), PatternVar(y)]), - self.s(self.length(y))) + add(const(1), self.length(y))) 
self.mod[self.length] = Function([x], - Match(x, [nil_case, cons_case]), None, [a]) + Match(x, [nil_case, cons_case]), scalar_type('int32'), [a]) + def define_tree_adt(self): """Defines a tree ADT. A tree can contain any type. @@ -420,6 +417,7 @@ def define_tree_adt(self): self.rose = Constructor("rose", [a, self.l(self.tree(a))], self.tree) self.mod[self.tree] = TypeData(self.tree, [a], [self.rose]) + def define_tree_map(self): """Defines a function that maps over a tree. The function is applied to each subtree's contents. @@ -439,71 +437,58 @@ def define_tree_map(self): self.mod[self.tmap] = Function([f, t], Match(t, [rose_case]), self.tree(b), [a, b]) + def define_tree_size(self): - """Defines a function that computes the size of a tree as a nat. + """Defines a function that computes the size of a tree. - Signature: fn(t : tree[a]) -> nat + Signature: fn(t : tree[a]) -> Tensor[(), int32] """ self.size = GlobalVar("size") a = TypeVar("a") t = Var("t", self.tree(a)) - x = Var("x", self.tree(a)) z = Var("z") rose_case = Clause(PatternConstructor(self.rose, [PatternWildcard(), PatternVar(z)]), - self.s(self.sum(self.map(Function([x], self.size(x)), z)))) + add(const(1), self.sum(self.map(self.size, z)))) self.mod[self.size] = Function([t], - Match(t, [rose_case]), self.nat(), [a]) - - def define_id(self): - """Defines a function that return it's argument. - - Signature: fn(x : a) -> a - """ - self.id = GlobalVar("id") - a = TypeVar("a") - x = Var("x", a) - self.mod[self.id] = Function([x], x, a, [a]) - - - def define_compose(self): - """Defines a function that compose two function. 
- - Signature: fn(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c - """ - self.compose = GlobalVar("compose") - a = TypeVar("a") - b = TypeVar("b") - c = TypeVar("c") - f = Var("f", FuncType([b], c)) - g = Var("g", FuncType([a], b)) - x = Var("x") - self.mod[self.compose] = Function([f, g], - Function([x], f(g(x))), - FuncType([a], c), - [a, b, c]) + Match(t, [rose_case]), scalar_type('int32'), [a]) def define_iterate(self): - """Define a function that take a number n, a function f, - and return a closure that apply f n time on it's argument. + """Defines a function that take a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. - Signature: fn(n : nat, f : fn(a) -> a) -> fn(a) -> a + Signature: fn(f : fn(a) -> a, n : Tensor[(), int32]) -> fn(a) -> a """ self.iterate = GlobalVar("iterate") a = TypeVar("a") f = Var("f", FuncType([a], a)) - x = Var("x", self.nat()) - y = Var("y", self.nat()) - z_case = Clause(PatternConstructor(self.z), self.id) - s_case = Clause(PatternConstructor(self.s, [PatternVar(y)]), - self.compose(f, self.iterate(f, y))) + x = Var("x", scalar_type('int32')) + body = If(equal(x, const(0)), + self.id, + self.compose(f, + self.iterate(f, subtract(x, const(1))))) self.mod[self.iterate] = Function([f, x], - Match(x, [z_case, s_case]), + body, FuncType([a], a), [a]) + def load_prelude(self): + """ + Parses the portions of the Prelude written in Relay's text format and adds + them to the module. 
+ """ + prelude_file = os.path.join(__PRELUDE_PATH__, "prelude.rly") + with open(prelude_file) as prelude: + prelude = fromtext(prelude.read()) + self.mod.update(prelude) + self.id = self.mod.get_global_var("id") + self.compose = self.mod.get_global_var("compose") + + def __init__(self, mod): self.mod = mod + self.load_prelude() self.define_list_adt() self.define_list_hd() self.define_list_tl() @@ -522,9 +507,6 @@ def __init__(self, mod): self.define_list_unfoldr() self.define_list_unfoldl() - self.define_nat_adt() - self.define_nat_double() - self.define_nat_add() self.define_list_length() self.define_list_nth() self.define_list_update() @@ -534,6 +516,4 @@ def __init__(self, mod): self.define_tree_map() self.define_tree_size() - self.define_id() - self.define_compose() self.define_iterate() diff --git a/python/tvm/relay/prelude.rly b/python/tvm/relay/prelude.rly new file mode 100644 index 000000000000..35c794a6d479 --- /dev/null +++ b/python/tvm/relay/prelude.rly @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +v0.0.2 + +def @id[a](%x: a) -> a { + %x +} + +def @compose[a, b, c](%f: fn(b) -> c, %g: fn(a) -> b) { + fn (%x: a) -> c { + %f(%g(%x)) + } +} diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index e52ce142e5c3..61e895ac7efb 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -22,7 +22,7 @@ import topi from . import _quantize from .quantize import QAnnotateKind, current_qconfig -from .quantize import _conv_counter, _set_conv_counter +from .quantize import annotate_context from .. import expr as _expr from .. import op as _op from ..op import op as _reg @@ -116,7 +116,6 @@ def frewrite_with_guard(ref_call, new_args, ctx): return _register(frewrite) if frewrite is not None else _register -@register_func("relay.quantize.attach_simulated_quantize") def attach_simulated_quantize(data, kind, sign=True, rounding="round"): """Attach a simulated quantize operation after input data expr. @@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data + actx = annotate_context() + key = tuple([data, kind, sign, rounding]) + if key in actx.qnode_map: + return actx.qnode_map[key] + dom_scale = _expr.var("dom_scale") clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") - return _quantize.simulated_quantize( + qnode = _quantize.simulated_quantize( data, dom_scale, clip_min, clip_max, kind, sign, rounding) + actx.qnode_map[key] = qnode + return qnode + +register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -152,11 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx): """Rewrite function for conv2d. Lhs of conv will be quantized to input field, and rhs of conv will be quantized to weight field. 
Output would be in activation field""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: - _set_conv_counter(cnt + 1) - return None - _set_conv_counter(cnt + 1) + actx = annotate_context() + if current_qconfig().skip_conv_layers is not None: + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if actx.conv2d_counter() in skipped_indices: + actx.count_conv2d() + return None + actx.count_conv2d() lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -168,16 +178,26 @@ def conv2d_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +def check_to_skip(): + """Check the index of conv2d layer to decide whether to skip the current operator.""" + if current_qconfig().skip_conv_layers is not None: + skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers] + if annotate_context().conv2d_counter() - 1 in skipped_indices: + return True + return False + + @register_annotate_function("nn.dense") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. 
Output would be in activation field.""" - cnt = _conv_counter() - if cnt < current_qconfig().skip_k_conv: + if check_to_skip(): return None + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) @@ -194,7 +214,7 @@ def dense_rewrite(ref_call, new_args, ctx): @register_annotate_function("multiply") def multiply_rewrite(ref_call, new_args, ctx): """Rewrite function for multiply.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + if check_to_skip(): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -216,7 +236,7 @@ def multiply_rewrite(ref_call, new_args, ctx): @register_annotate_function("add") def add_rewrite(ref_call, new_args, ctx): """Rewrite function for add.""" - if _conv_counter() <= current_qconfig().skip_k_conv: + if check_to_skip(): return None lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) @@ -242,9 +262,24 @@ def add_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +@register_annotate_function("stop_fusion") +def stop_fusion_rewrite(ref_call, new_args, ctx): + """Rewrite function for add.""" + if check_to_skip(): + return None + + x_expr, x_kind = _get_expr_kind(new_args[0]) + if x_kind is None: + return None + + ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT) + ret_expr = _forward_op(ref_call, [ret_expr]) + return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT) + + def identity_rewrite(ref_call, new_args, ctx): """Simply forward the original operation""" - if _conv_counter() <= current_qconfig().skip_k_conv: + if check_to_skip(): return None x_expr, x_kind = _get_expr_kind(new_args[0]) @@ -255,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(ret_expr, x_kind) +register_annotate_function("clip", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite) register_annotate_function("nn.avg_pool2d", identity_rewrite) @@ -262,8 
+298,9 @@ def identity_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx): """Rewrite function for max pool2d""" - if _conv_counter() <= current_qconfig().skip_k_conv: + if check_to_skip(): return None + expr, x_kind = _get_expr_kind(new_args[0]) if x_kind is None: @@ -280,7 +317,7 @@ def pool2d_rewrite(ref_call, new_args, ctx): @register_annotate_function("concatenate") def concatenate_rewrite(ref_call, new_args, ctx): """Rewrite function for concatenate""" - if _conv_counter() <= current_qconfig().skip_k_conv: + if check_to_skip(): return None input_tuple = new_args[0] diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 607ee1821c86..a7749d4892fb 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -21,8 +21,9 @@ from . import _quantize from .. import expr as _expr +from .. import module as _module from .. import ir_pass as _ir_pass -from .. import build_module as _build +from .. import transform as _transform from .. import op as _op from ... import make as _make from ..base import NodeBase, register_relay_node @@ -70,11 +71,10 @@ class QConfig(NodeBase): "dtype_weight": "int8", "dtype_activation": "int32", "global_scale": 8.0, - "skip_k_conv": 1, + "skip_conv_layers": [0], "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, - "use_stop_fusion": True } # pylint: disable=no-member @@ -136,8 +136,9 @@ def qconfig(**kwargs): global_scale: float The global scale for calibration. - skip_k_conv: int - The number of skipped conv2d. + skip_conv_layers: list + Specifying which layers to be skipped. Provide a list of indices + that indicate which conv2d layers to leave untouched. round_for_shift: boolean Whether to add bias for rounding during shift. @@ -146,9 +147,10 @@ def qconfig(**kwargs): Whether to store low-bit integer back as output before dequantizing. Some accelerators need this, e.g. VTA. 
- use_stop_fusion: boolean - Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit - results to be stored in memory. + debug_enabled_ops: None or list of str + Partially quantize specified operators for debugging. The default value + is None, which means will try to call all operartors' annotate rewrite + function. Returns ------- @@ -160,40 +162,38 @@ def qconfig(**kwargs): return _make.node("relay.quantize.QConfig", **node_args) -CONV_COUNTER = 0 +class AnnotateContext(object): + """A global singleton annotate scope""" + Current = None + def __init__(self): + self.qnode_map = dict() + self._conv2d_counter = 0 -def _conv_counter(): - """Get the global counter for conv2d.""" - return CONV_COUNTER + def __enter__(self): + self._conv2d_counter = 0 + return self + def conv2d_counter(self): + """Get the counter for conv2d.""" + return self._conv2d_counter -def _set_conv_counter(n): - """Set the value of the global conv2d counter.""" - global CONV_COUNTER - CONV_COUNTER = n + def count_conv2d(self): + """Increase the value of the conv2d counter by one.""" + self._conv2d_counter += 1 + def __exit__(self, ptype, value, traceback): + pass -def annotate(graph): - """Given a float32 graph, annotate will rewrite the graph - and return back a graph which simulates the error brought by - current quantization scheme. - Parameters - --------- - graph: Function - The original graph - - Returns - ------- - ret: Function - The graph after annotation - """ - _set_conv_counter(0) # reset counter - return _quantize.annotate(graph) +def annotate_context(): + """Get the global singleton scope""" + if AnnotateContext.Current is None: + AnnotateContext.Current = AnnotateContext() + return AnnotateContext.Current -def calibrate(graph, dataset=None): +def calibrate(graph, mod=None, ctx=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. 
@@ -203,8 +203,11 @@ def calibrate(graph, dataset=None): graph: Function The simulation graph after annotation. - dataset: list of dict of Var -> NDArray - The calibration dataset. + mod: tvm.relay.Module + The module where calibration happens on. + + ctx: tvm.relay.PassContext + The pass context used for calibration. Returns ------- @@ -249,24 +252,52 @@ def _make_const(val): return _expr.bind(graph, const_params) -def realize(graph): - """The realize pass will transform the simulated quantized - graph, which computes with float32 actually, to a real low-bit - integer graph. It will replace the simulated_quantize with - several fine-grained operators like add, multiply, and shift - as more as possible for performance (fusion, etc.) +def annotate(): + """Given a float32 graph, this pass will rewrite the graph and return + a graph which simulates the error brought by the current quantization + scheme. - Parameters - --------- - graph: Function - The simulated graph after calibrating. + Returns + ------- + ret: tvm.relay.Pass + The registered pass for quantization annotation. + """ + return _quantize.QuantizeAnnotate() + + +def realize(): + """The realize pass will transform the simulated quantized graph, which + actually computes with float32, to a real low-bit integer graph. It will + replace the `simulated_quantize` with several fine-grained operators like + add, multiply, and shift as much as possible for better performance. Returns ------- - ret: Function - The graph after realization + ret: tvm.relay.Pass + The registered pass for quantization realization. + """ + return _quantize.QuantizeRealize() + + +def _bind_params(func, params): + """Bind the params to the expression. 
""" - return _quantize.realize(graph) + name_dict = {} + for arg in func.params: + name = arg.name_hint + if name in name_dict: + name_dict[name] = None + else: + name_dict[name] = arg + bind_dict = {} + for k, v in params.items(): + if k not in name_dict: + continue + arg = name_dict[k] + if arg is None: + raise ValueError("Multiple args in the function have name %s" % k) + bind_dict[arg] = _expr.const(v) + return _expr.bind(func, bind_dict) def quantize(graph, params=None, dataset=None): @@ -292,15 +323,29 @@ def quantize(graph, params=None, dataset=None): ret: Function The graph after quantization """ - opt_passes = ["SimplifyInference", - "FoldScaleAxis", - "FoldConstant", - "CanonicalizeOps"] - with _build.build_config(add_pass=opt_passes): - graph = _build.optimize(graph, params=params) - - graph = annotate(graph) - graph = calibrate(graph, dataset) - graph = realize(graph) - graph = _ir_pass.fold_constant(graph) - return graph + if params: + graph = _bind_params(graph, params) + + mod = _module.Module.from_expr(graph) + # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and + # "CanonicalizeOps" optimization before quantization. 
+ optimize = _transform.Sequential([_transform.SimplifyInference(), + _transform.FoldConstant(), + _transform.FoldScaleAxis(), + _transform.CanonicalizeOps(), + _transform.FoldConstant()]) + + calibrate_pass = _transform.function_pass(calibrate, opt_level=1, + name="QuantizeCalibrate") + quantize_seq = _transform.Sequential([annotate(), + calibrate_pass, + realize(), + _transform.FoldConstant()]) + with annotate_context(): + with _transform.PassContext(opt_level=3, + required_pass=["QuantizeAnnotate", + "QuantizeCalibrate", + "QuantizeRealize"]): + mod = optimize(mod) + mod = quantize_seq(mod) + return mod[mod.entry_func.name_hint] diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index b4a8394e2659..7a5007bbfb8f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -27,6 +27,8 @@ from . import squeezenet from . import vgg from . import densenet +from . import yolo_detection from .config import ctx_list from .init import create_workload +from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr diff --git a/nnvm/python/nnvm/testing/darknet.py b/python/tvm/relay/testing/darknet.py similarity index 100% rename from nnvm/python/nnvm/testing/darknet.py rename to python/tvm/relay/testing/darknet.py diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py new file mode 100644 index 000000000000..a76a340f113d --- /dev/null +++ b/python/tvm/relay/testing/nat.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Defines a unary natural number (Peano natural number) abstract +data type for Relay and provides some utility functions for it. +Nats are useful for testing purposes, as they make it easy to write +test cases for recursion and pattern matching.""" + +from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay.expr import Var, Function, GlobalVar +from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType + +def define_nat_adt(prelude): + """Defines a Peano (unary) natural number ADT. + Zero is represented by z(). s(n) adds 1 to a nat n. + Adds the fields nat, z, and s to the prelude, representing + (respectively) the nat ADT and the z and s constructors. + """ + prelude.nat = GlobalTypeVar("nat") + prelude.z = Constructor("z", [], prelude.nat) + prelude.s = Constructor("s", [prelude.nat()], prelude.nat) + prelude.mod[prelude.nat] = TypeData(prelude.nat, [], [prelude.z, prelude.s]) + + +def define_nat_double(prelude): + """Defines a function that doubles a nat. Adds a field called + 'double' to the prelude, giving the GlobalVar pointing to + the function.
+ """ + prelude.double = GlobalVar("double") + x = Var("x", prelude.nat()) + y = Var("y") + z_case = Clause(PatternConstructor(prelude.z), prelude.z()) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.s(prelude.s(prelude.double(y)))) + prelude.mod[prelude.double] = Function([x], Match(x, [z_case, s_case])) + + +def define_nat_add(prelude): + """Defines a function that adds two nats and adds a field to the + prelude 'add' giving the GlobalVar pointing to that function. + """ + prelude.add = GlobalVar("add") + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + a = Var("a") + z_case = Clause(PatternConstructor(prelude.z), y) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(a)]), + prelude.s(prelude.add(a, y))) + prelude.mod[prelude.add] = Function([x, y], Match(x, [z_case, s_case])) + + +# versions of prelude functions that use nats instead of scalars + +def define_nat_nth(prelude): + """Defines a function to get the nth element of a list using + a nat to index into the list. + + nat_nth(l, n): fun(list[a], nat) -> a + """ + prelude.nat_nth = GlobalVar("nat_nth") + a = TypeVar("a") + x = Var("x", prelude.l(a)) + n = Var("n", prelude.nat()) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), prelude.hd(x)) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.nat_nth(prelude.tl(x), y)) + + prelude.mod[prelude.nat_nth] = Function([x, n], + Match(n, [z_case, s_case]), + a, [a]) + + +def define_nat_update(prelude): + """Defines a function to update the nth element of a list and return the updated list.
+ + nat_update(l, i, v) : fun(list[a], nat, a) -> list[a] + """ + prelude.nat_update = GlobalVar("nat_update") + a = TypeVar("a") + # pylint: disable=invalid-name + l = Var("l", prelude.l(a)) + n = Var("n", prelude.nat()) + v = Var("v", a) + y = Var("y") + + z_case = Clause(PatternConstructor(prelude.z), + prelude.cons(v, prelude.tl(l))) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.cons( + prelude.hd(l), + prelude.nat_update(prelude.tl(l), y, v))) + + prelude.mod[prelude.nat_update] = Function([l, n, v], + Match(n, [z_case, s_case]), + prelude.l(a), [a]) + + +def define_nat_iterate(prelude): + """Defines a function that takes a number n and a function f; + returns a closure that takes an argument and applies f + n times to its argument. + + Signature: fn(fn(a) -> a, nat) -> fn(a) -> a + """ + prelude.nat_iterate = GlobalVar("nat_iterate") + a = TypeVar("a") + f = Var("f", FuncType([a], a)) + x = Var("x", prelude.nat()) + y = Var("y", prelude.nat()) + + z_case = Clause(PatternConstructor(prelude.z), prelude.id) + s_case = Clause(PatternConstructor(prelude.s, [PatternVar(y)]), + prelude.compose(f, prelude.nat_iterate(f, y))) + + prelude.mod[prelude.nat_iterate] = Function([f, x], + Match(x, [z_case, s_case]), + FuncType([a], a), + [a]) + + +def add_nat_definitions(prelude): + """Given a Relay prelude, adds a Peano nat ADT, as well as functions + for adding nats and doubling nats. It also adds versions of + update, nth, and iterate that take nats instead of scalars (the + names are prefixed with 'nat_').""" + define_nat_adt(prelude) + define_nat_double(prelude) + define_nat_add(prelude) + define_nat_nth(prelude) + define_nat_update(prelude) + define_nat_iterate(prelude) + + +# helper functions for working with nats + + +def count(prelude, n): + """Takes a ConstructorValue corresponding to a nat ADT + and converts it into a Python integer. This is an example of + using an ADT value in Python. 
+ """ + assert isinstance(n, ConstructorValue) + if n.tag == prelude.z.tag: + return 0 + assert n.tag == prelude.s.tag + return 1 + count(prelude, n.fields[0]) + + +def make_nat_value(prelude, n): + """The inverse of count(): Given a non-negative Python integer, + constructs a ConstructorValue representing that value as a nat. + """ + if n == 0: + return ConstructorValue(prelude.z.tag, [], None, []) + return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, []) + + +def make_nat_expr(prelude, n): + """Given a non-negative Python integer, constructs a Python + expression representing that integer's value as a nat. + """ + assert n >= 0 + ret = prelude.z() + while n > 0: + ret = prelude.s(ret) + n = n - 1 + return ret diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index d82ed0f46097..a56e6fe1782d 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -163,13 +163,10 @@ def get_workload_official(model_url, model_sub_path): model_sub_path: Sub path in extracted tar for the ftozen protobuf file. - temp_dir: TempDirectory - The temporary directory object to download the content. - Returns ------- - graph_def: graphdef - graph_def is the tensorflow workload for mobilenet. + model_path: str + Full path to saved model file """ @@ -200,7 +197,7 @@ def get_workload(model_path, model_sub_path=None): Returns ------- graph_def: graphdef - graph_def is the tensorflow workload for mobilenet. + graph_def is the tensorflow workload. 
""" diff --git a/nnvm/python/nnvm/testing/yolo_detection.py b/python/tvm/relay/testing/yolo_detection.py similarity index 100% rename from nnvm/python/nnvm/testing/yolo_detection.py rename to python/tvm/relay/testing/yolo_detection.py diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py new file mode 100644 index 000000000000..5f47e5b446aa --- /dev/null +++ b/python/tvm/relay/transform.py @@ -0,0 +1,725 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" +Relay pass transformation infrastructure. +""" +import types +import inspect +import functools + +from tvm._ffi.runtime_ctypes import TVMContext +from . import _transform +from .base import RelayNode, register_relay_node +from .. import nd as _nd + + +@register_relay_node +class PassInfo(RelayNode): + """The class that contains the meta data required by a pass. It is the + container of information needed by running an optimization or analysis. + This class can be extended by adding new members when more meta data is + needed. + + Parameters + ---------- + opt_level : int + The optimization level of this pass. + + name : str + The pass name. + + required : List[str] + The list of passes that are required by a certain pass. 
+ """ + + def __init__(self, opt_level, name, required=None): + self.__init_handle_by_constructor__( + _transform.PassInfo, opt_level, name, required) + + +@register_relay_node +class PassContext(RelayNode): + """The basis where a Relay optimization/analysis runs on. + Each pass context contains a number of auxiliary information that is used + to help an optimization pass. Such information includes the error reporter + to record the errors of during the optimization, etc. + + opt_level : Optional[int] + The optimization level of this pass. + + fallback_device : Optional[Union[int, str, TVMContext]] + The fallback device type. It is also used as the default device for + operators that are not annotated during heterogeneous execution. + + required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are required by a certain pass. + + disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] + The list of passes that are disabled. + """ + def __init__(self, + opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + if isinstance(fallback_device, str): + fallback_device = _nd.context(fallback_device).device_type + elif isinstance(fallback_device, TVMContext): + fallback_device = fallback_device.device_type + if not isinstance(fallback_device, int): + raise TypeError("required_pass is expected to be the type of " + + "int/str/TVMContext.") + + required = list(required_pass) if required_pass else [] + if not isinstance(required, (list, tuple)): + raise TypeError("required_pass is expected to be the type of " + + "list/tuple/set.") + + disabled = list(disabled_pass) if disabled_pass else [] + if not isinstance(disabled, (list, tuple)): + raise TypeError("disabled_pass is expected to be the type of " + + "list/tuple/set.") + + self.__init_handle_by_constructor__(_transform.PassContext, opt_level, + fallback_device, required, + disabled) + + def __enter__(self): + _transform.EnterPassContext(self) + return 
self + + def __exit__(self, ptype, value, trace): + _transform.ExitPassContext(self) + + @staticmethod + def current(): + """Return the current pass context.""" + return _transform.GetCurrentPassContext() + + +def build_config(opt_level=2, + fallback_device=_nd.cpu(), + required_pass=None, + disabled_pass=None): + """Configure the build behavior by setting config variables. + + Parameters + ---------- + opt_level: int, optional + Optimization level. The optimization pass name and level are as the + following: + + .. code-block:: python + + OPT_PASS_LEVEL = { + "SimplifyInference": 0, + "OpFusion": 1, + "FoldConstant": 2, + "CombineParallelConv2D": 3, + "FoldScaleAxis": 3, + "AlterOpLayout": 3, + "CanonicalizeOps": 3, + "EliminateCommonSubexpr": 3, + } + + fallback_device : int, str, or tvm.TVMContext, optional + The fallback device. It is also used as the default device for + operators without specified device during heterogeneous execution. + + required_pass: set of str, optional + Optimization passes that are required regardless of optimization level. + + disabled_pass: set of str, optional + Optimization passes to be disabled during optimization. + + Returns + ------- + pass_context: PassContext + The pass context for optimizations. + """ + return PassContext(opt_level, fallback_device, required_pass, + disabled_pass) + + +@register_relay_node +class Pass(RelayNode): + """The base class of all passes. All methods here are just simple wrappers + that are implemented in the backend. They are defined for users to + conveniently interact with the base class. + """ + + @property + def info(self): + """Get the pass meta.""" + return _transform.Info(self) + + def __call__(self, mod): + """Execute the pass. Note that for sequential pass, the dependency among + different passes will be resolved in the backend. + + Parameters + ---------- + mod : tvm.relay.Module + The module that a certain optimization is performed on. 
+ + Returns + ------- + mod : tvm.relay.Module + The updated module after applying this pass. + """ + return _transform.RunPass(self, mod) + + +@register_relay_node +class ModulePass(Pass): + """A pass that works on tvm.relay.Module. Users don't need to interact with + this class directly. Instead, a module pass should be created through + `module_pass`, because the design of the `module_pass` API is flexible + enough to handle the creation of a module pass in different manners. In + addition, all members of a module pass can be accessed from the base class. + The same rule applies to FunctionPass as well. + """ + + +@register_relay_node +class FunctionPass(Pass): + """A pass that works on each tvm.relay.Function in a module. A function + pass class should be created through `function_pass`. + """ + + +@register_relay_node +class Sequential(Pass): + """A pass that works on a sequence of pass objects. Multiple passes can be + executed sequentially using this class. + + Some typical usage of the sequential pass are: + 1. Users provide a list of passes for optimization. + 2. Only an optimization level is provided so that the backend system has + to glob all passes at this level and below to perform the optimizations. + Note that users can also provide a series of passes that they don't want to + apply when running a sequential pass. Pass dependency will be resolved in + the backend as well. + + Parameters + ---------- + passes : Optional[List[Pass]] + A sequence of passes candidate for optimization. + + opt_level : Optional[int] + The optimization level of this sequential pass. + + name : Optional[str] + The name of the sequential pass. + + required : Optional[List[str]] + The list of passes that the sequential pass is dependent on. 
+ """ + + def __init__(self, + passes=None, + opt_level=2, + name="sequential", + required=None): + passes = passes if passes else [] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a list of Pass objects.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of list/tuple.") + + self.__init_handle_by_constructor__(_transform.Sequential, + passes, opt_level, name, required) + + +def InferType(): + """Infer the type of an expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered type inference pass. + """ + return _transform.InferType() + + +def FoldScaleAxis(): + """Fold the scaling of axis into weights of conv2d/dense. This pass will + invoke both forward and backward scale folding. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to fold expressions. + + Note + ---- + Internally, we will call backward_fold_scale_axis before using + forward_fold_scale_axis. As backward folding targets common conv-bn + pattern. + """ + return _transform.FoldScaleAxis() + + +def SimplifyInference(): + """Simplify the data-flow graph for inference phase. An simplified expression + which is semantically equal to the input expression will be returned. + + Returns + ------- + ret: tvm.relay.Pass + The registered to perform operator simplification. + """ + return _transform.SimplifyInference() + + +def CanonicalizeOps(): + """ Canonicalize special operators to basic operators. + This can simplify followed analysis. (e.g. expanding bias_add to + expand_dims and broadcast_add.) + + Returns + ------- + ret: tvm.relay.Pass + The registered pass performing the canonicalization. + """ + return _transform.CanonicalizeOps() + + +def DeadCodeElimination(): + """ Remove expressions which does not effect the program result (dead code). + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that eliminates the dead code in a Relay program. 
+ """ + return _transform.DeadCodeElimination() + + +def FoldConstant(): + """Fold the constant expression in expr. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for constant folding. + """ + return _transform.FoldConstant() + + +def FuseOps(fuse_opt_level=-1): + """Fuse operators in an expr to a larger operator according to some rules. + + Parameters + ---------- + fuse_opt_level : int + The level of fuse optimization. -1 indicates that the level will be + inferred from pass context. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass for operator fusion. + """ + return _transform.FuseOps(fuse_opt_level) + + +def CombineParallelConv2D(min_num_branches=3): + """Combine multiple conv2d operators into one. + + Parameters + ---------- + min_num_branches : int + The minimum number of required parallel branches for performing this + optimization. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that combines parallel conv2d operators. + """ + return _transform.CombineParallelConv2D(min_num_branches) + + +def AlterOpLayout(): + """Alternate the layouts of operators or replace primitive operators with + other expressions. + This pass can be used for computing convolution in custom layouts or + other general weight pre-transformation. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that alters the layout of operators. + """ + return _transform.AlterOpLayout() + + +def RewriteAnnotatedOps(fallback_device): + """Rewrite the annotated program where annotation operators, e.g. + `on_deivce`, mark which device an expression should be scheduled to. + This pass helps heterogeneous execution where different operators may need + to be allocated on various devices. + + Parameters + ---------- + fallback_device : int + The fallback device type. It is also used as the default device for + operators with no annotated device. 
+ + Returns + ------- + ret: tvm.relay.Pass + The registered pass that rewrites an expression with annotated + `on_device` operators. + """ + return _transform.RewriteDeviceAnnotation(fallback_device) + + +def ToANormalForm(): + """Turn Graph Normal Form expression into A Normal Form Expression. + The scope of the root expression is the global scope. + The scope of any non root expression is the least common ancestor of all it's scope. + Values are ordered by post-DFS order in each scope. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that transforms an expression into A Normal Form. + """ + return _transform.ToANormalForm() + +def EtaExpand(): + """Add abstraction over a function + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that eta expands an expression. + """ + return _transform.EtaExpand() + +def ToGraphNormalForm(): + """Turn A Normal Form expression into Graph Normal Form expression + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that transforms an expression into Graph Normal Form. + """ + return _transform.ToGraphNormalForm() + + +def EliminateCommonSubexpr(fskip=None): + """Eliminate common subexpressions. + + Parameters + ---------- + fskip: Callable + The callback function that decides whether an expression should be + skipped. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that eliminates common subexpressions. + """ + return _transform.EliminateCommonSubexpr(fskip) + + +def PartialEvaluate(): + """Evaluate the static fragment of the code. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that performs partial evaluation on an expression. + """ + return _transform.PartialEvaluate() + +def CanonicalizeCast(): + """ + Canonicalize cast expressions to make operator fusion more efficient. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass that canonicalizes cast expression. 
+ """ + return _transform.CanonicalizeCast() + +def _wrap_class_module_pass(pass_cls, pass_info): + """Wrap a python class as a module pass""" + class PyModulePass(ModulePass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation fails + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(mod, ctx): + return inst.transform_module(mod, ctx) + self.__init_handle_by_constructor__( + _transform.MakeModulePass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__) + PyModulePass.__name__ = pass_cls.__name__ + PyModulePass.__doc__ = pass_cls.__doc__ + PyModulePass.__module__ = pass_cls.__module__ + return PyModulePass + + +def module_pass(pass_func=None, opt_level=None, name=None, required=None): + """Decorate a module pass. + + This function returns a callback when pass_func is provided. + Otherwise, it serves as a decorator function. + + pass_func can also be a class type with a method transform_module. + This function will create a decorated ModulePass using transform_module + as the pass function. + + Parameters + ---------- + pass_func : Optional[Callable[(Module, PassContext) -> Module]] + The transformation function or class. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the module pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on.
+ + Returns + ------- + create_module_pass : Union[Callable, ModulePass] + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new ModulePass will be returned when we decorate a pass function. + A new ModulePass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a module pass class. + + .. code-block:: python + + @relay.transform.module_pass + class CustomPipeline: + def __init__(self, enable_fold): + self.enable_fold = enable_fold + self.cse = relay.transform.EliminateCommonSubexpr() + self.const_fold = relay.transform.FoldConstant() + + def transform_module(self, mod, ctx): + mod = self.cse(mod, ctx) + if self.enable_fold: + mod = self.const_fold(mod, ctx) + return mod + + # create an instance of customized pipeline + pipeline = CustomPipeline(enable_fold=False) + assert isinstance(pipeline, transform.ModulePass) + # run the pipeline. + output_module = pipeline(input_module) + + The following code creates a module pass by decorating + a user defined transform function. + + .. code-block:: python + + @relay.transform.module_pass(opt_level=2) + def transform(mod, ctx): + tp = relay.TensorType((10,), "float32") + x = relay.var("x", tp) + gv = relay.GlobalVar("var") + func = relay.Function([x], relay.abs(x)) + new_mod = relay.Module({gv: func}) + new_mod.update(mod) + return new_mod + + module_pass = transform + assert isinstance(module_pass, transform.ModulePass) + assert module_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the following: + updated_mod = module_pass(m) + # Now a function abs should be added to the module m.
+ """ + if opt_level is None: + raise ValueError("Please provide opt_level for the module pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_module_pass(pass_arg): + """Internal function that creates a module pass""" + fname = name if name else pass_arg.__name__ + info = PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_module_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + return _transform.MakeModulePass(pass_arg, info) + + if pass_func: + return create_module_pass(pass_func) + return create_module_pass + + +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass""" + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in case pass_cls creation fails + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + self.__init_handle_by_constructor__( + _transform.MakeFunctionPass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + +def function_pass(pass_func=None, opt_level=None, name=None, required=None): + """Decorate a function pass. + + This function returns a callback when pass_func + is provided.
Otherwise, it returns the created function pass using the + given optimization function. + + Parameters + ---------- + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. + + opt_level : int + The optimization level of this module pass. + + name : Optional[str] + The name of the function pass. The name could be empty. In this case, the + name of the optimization function will be used as the pass name. + + required : Optional[List[str]] + The list of passes that the module pass is dependent on. + + Returns + ------- + create_function_pass : Union[Callable, FunctionPass] + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. + + Examples + -------- + The following code block decorates a function pass class. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + x = relay.var("x", shape=(10, 20)) + f1 = relay.Function([x], x) + f2 = relay.Function([x], relay.log(x)) + # fpass is now a special pass that replaces every + # function to f1 + fpass = TestReplaceFunc(f1) + # now every function in input_mod is replaced by f1 + res_mod = fpass(input_mod) + + + The following code creates a function pass by decorating + a user defined transform function. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=2) + def transform(func, mod, ctx): + # my transformations here. 
+ return func + + function_pass = transform + assert isinstance(function_pass, transform.FunctionPass) + assert function_pass.info.opt_level == 2 + + # Given a module m, the optimization could be invoked as the following: + updated_mod = function_pass(m) + # Now constant folding should have been applied to every function in + # the provided module m. And the updated module will be returned. + """ + + if opt_level is None: + raise ValueError("Please provide opt_level for the function pass.") + + required = required if required else [] + if not isinstance(required, (list, tuple)): + raise TypeError("Required is expected to be the type of " + + "list/tuple.") + + def create_function_pass(pass_arg): + """Internal function that creates a function pass""" + fname = name if name else pass_arg.__name__ + info = PassInfo(opt_level, fname, required) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Function pass") + return _transform.MakeFunctionPass(pass_arg, info) + + if pass_func: + return create_function_pass(pass_func) + return create_function_pass diff --git a/python/tvm/relay/transform.pyi b/python/tvm/relay/transform.pyi new file mode 100644 index 000000000000..343e89976b09 --- /dev/null +++ b/python/tvm/relay/transform.pyi @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from .base import NodeBase + + +class PassContext(NodeBase): + def __init__(self): + ... + +class PassInfo(NodeBase): + name = ... # type: str + opt_level = ... # type: int + required = ... # type: list + + def __init__(self, name, opt_level, required): + ... # type: (str, int, list) -> None + + +class Pass(NodeBase): + def __init__(self): + ... + + +class ModulePass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class FunctionPass(Pass): + name = ... # type: str + opt_level = ... # type: int + pass_func = ... # type: Callable + required = ... # type: list + + def __init__(self, name, opt_level, pass_func, required): + # type: (str, int, Callable, list) -> None + ... + + +class Sequential(Pass): + name = ... # type: str + opt_level = ... # type: int + passes = ... # type: list + required = ... # type: list + disabled = ... # type: list + + def __init__(self, name, opt_level, passes, required, disabled): + # type: (str, int, list, list, list) -> None + ...
diff --git a/python/tvm/target.py b/python/tvm/target.py index eff0088b37ce..828fff8e228c 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -133,7 +133,7 @@ def __enter__(self): return self def __exit__(self, ptype, value, trace): - _api_internal._ExitTargetScope() + _api_internal._ExitTargetScope(self) @register_node diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index ce7cbae385d9..db8fb272a551 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -147,7 +147,7 @@ def output(self, index): @property def num_outputs(self): - """Number of outputs of this op.""" + """Number of outputs from this op.""" return _api_internal._OpNumOutputs(self) @property @@ -166,7 +166,7 @@ class BaseComputeOp(Operation): """Compute operation.""" @property def axis(self): - """Represent axis of IterVar, defined when it is a ComputeOp""" + """Represent the IterVar axis, defined when it is a ComputeOp""" return self.__getattr__("axis") @property @@ -191,7 +191,7 @@ class ScanOp(Operation): """Scan operation.""" @property def scan_axis(self): - """Represent axis of scan, only defined when it is a ScanOp""" + """Represent the scan axis, only defined when it is a ScanOp""" return self.__getattr__("scan_axis") @@ -205,7 +205,7 @@ class HybridOp(Operation): """Hybrid operation.""" @property def axis(self): - """Represent axis of IterVar, also defined when it is a HybridOp""" + """Represent the IterVar axis, also defined when it is a HybridOp""" return self.__getattr__("axis") diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index f97e6b7579a1..2ef7a4bbb23e 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -50,20 +50,23 @@ class TensorIntrin(NodeBase): decl_tensor_intrin: Construct a TensorIntrin """ def __call__(self, *args, **kwargs): - tensors = [x.tensor for x in args] - regions = [_get_region(x) for x in args] + tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)] + scalar_inputs = [x 
for x in args if not isinstance(x, _tensor.TensorSlice)] + regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)] reduce_axis = [] if "reduce_axis" in kwargs: reduce_axis = kwargs["reduce_axis"] if not isinstance(reduce_axis, (list, tuple)): reduce_axis = [reduce_axis] reduce_axis = _api.convert(reduce_axis) - return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis) + if scalar_inputs: + scalar_inputs = _api.convert(scalar_inputs) + return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs) def decl_tensor_intrin(op, fcompute, name="tensor_intrin", - binds=None): + binds=None, scalar_params=None): """Declare a tensor intrinsic function. Parameters @@ -96,6 +99,9 @@ def decl_tensor_intrin(op, requirement of the function. By default, a new compact buffer is created for each tensor in the argument. + scalar_params: a list of variables used by op, whose values will be passed + as scalar_inputs when the tensor intrinsic is called. 
+ Returns ------- intrin: TensorIntrin @@ -122,11 +128,15 @@ def decl_tensor_intrin(op, offset_factor=cfg.offset_factor)) binds_list.append(buf) - body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) + if scalar_params: + body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params) + else: + body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):]) + scalar_params = [] if isinstance(body, (_expr.Expr, _stmt.Stmt)): body = [body] body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body] if len(body) < 3: body += [None] * (3 - len(body)) return _api_internal._TensorIntrin( - name, op, inputs, binds_list, *body) + name, op, inputs, binds_list, scalar_params, *body) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 25466e08bdf9..02e2c7c67c99 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -18,8 +18,10 @@ [workspace] members = [ "common", + "macros", "runtime", "runtime/tests/test_tvm_basic", + "runtime/tests/test_tvm_dso", "runtime/tests/test_nnvm", "frontend", "frontend/tests/basics", diff --git a/rust/common/build.rs b/rust/common/build.rs index 5dac99ec54bb..919e0adc46c8 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -22,23 +22,30 @@ extern crate bindgen; use std::path::PathBuf; fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let tvm_home = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + tvm_home + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", env!("TVM_HOME")); + println!("cargo:rustc-link-search={}/build", tvm_home); } // @see rust-bindgen#550 for `blacklist_type` bindgen::Builder::default() - .header(format!( - "{}/include/tvm/runtime/c_runtime_api.h", - env!("TVM_HOME") - )) - 
.header(format!( - "{}/include/tvm/runtime/c_backend_api.h", - env!("TVM_HOME") - )) - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", env!("TVM_HOME"))) + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) .blacklist_type("max_align_t") .layout_tests(false) .derive_partialeq(true) diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml new file mode 100644 index 000000000000..15773b625be9 --- /dev/null +++ b/rust/macros/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-macros" +version = "0.1.0" +license = "Apache-2.0" +description = "Proc macros used by the TVM crates." 
+repository = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["tvm"] +authors = ["TVM Contributors"] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +goblin = "0.0.22" +proc-macro2 = "0.4" +proc-quote = "0.2" +syn = "0.15" diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs new file mode 100644 index 000000000000..704f7c1de58b --- /dev/null +++ b/rust/macros/src/lib.rs @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#![feature(bind_by_move_pattern_guards, proc_macro_span)] + +extern crate proc_macro; + +use std::{fs::File, io::Read}; + +use proc_quote::quote; + +#[proc_macro] +pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let obj_file_path = syn::parse_macro_input!(input as syn::LitStr); + + let mut path = obj_file_path.span().unwrap().source_file().path(); + path.pop(); // remove the filename + path.push(obj_file_path.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] + } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } + }; + + let fns = quote! 
{ + use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[TVMArgValue]) -> Result { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) +} diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index 8e70565a6c13..3c81a93c9bbf 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -41,7 +41,11 @@ nom = {version = "4.0.0", default-features = false } serde = "1.0.59" serde_derive = "1.0.79" serde_json = "1.0.17" -tvm-common = { version = "0.1.0", path = "../common/" } +tvm-common = { version = "0.1", path = "../common" } +tvm-macros = { version = "0.1", path = "../macros" } [target.'cfg(not(target_env = "sgx"))'.dependencies] num_cpus = "1.8.0" + +[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] +libloading = "0.5" diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index bff02f504a5e..cacd7a38a97f 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -164,7 +164,7 @@ impl<'a> TryFrom<&'a str> for Graph { /// ``` pub struct GraphExecutor<'m, 't> { graph: Graph, - op_execs: Vec>, + op_execs: Vec>, tensors: Vec>, } @@ -240,7 +240,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { graph: &Graph, lib: &'m M, tensors: &Vec>, - ) -> Result>, Error> { + ) -> Result>, Error> { ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); @@ -279,7 +279,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { }) .collect::, Error>>() .unwrap(); - let op: Box = box move || { + let op: Box = 
box move || { let args = dl_tensors .iter() .map(|t| t.into()) diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index c774d5bbc983..010fbf7d6a29 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -29,7 +29,6 @@ //! For examples of use, please refer to the multi-file tests in the `tests` directory. #![feature( - alloc, allocator_api, box_syntax, fn_traits, @@ -77,6 +76,7 @@ pub use tvm_common::{ packed_func::{self, *}, TVMArgValue, TVMRetValue, }; +pub use tvm_macros::import_module; pub use self::{array::*, errors::*, graph::*, module::*, threading::*, workspace::*}; diff --git a/rust/runtime/src/module/dso.rs b/rust/runtime/src/module/dso.rs new file mode 100644 index 000000000000..3442fad13bf9 --- /dev/null +++ b/rust/runtime/src/module/dso.rs @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::{ + cell::RefCell, + collections::HashMap, + ffi::CStr, + os::raw::{c_char, c_int, c_void}, + pin::Pin, +}; + +use tvm_common::{ffi::BackendPackedCFunc, packed_func::PackedFunc}; + +use crate::{ + threading::{TVMBackendParallelBarrier, TVMBackendParallelLaunch}, + workspace::{TVMBackendAllocWorkspace, TVMBackendFreeWorkspace}, + TVMAPISetLastError, +}; + +use super::Module; + +const TVM_MAIN: &'static [u8] = b"__tvm_main__"; +const TVM_MODULE_CTX: &'static [u8] = b"__tvm_module_ctx"; + +/// A module backed by a Dynamic Shared Object (dylib). +pub struct DsoModule<'a> { + lib: libloading::Library, + packed_funcs: RefCell>, + _pin: std::marker::PhantomPinned, +} + +macro_rules! init_context_func { + ($lib:ident, $( ($fn:ident, $sig:ty) ),+ $(,)?) => { + unsafe { + $( + let fn_ptr = $lib.get::<*mut $sig>(concat!("__", stringify!($fn)).as_bytes()); + if let Ok(fn_ptr) = fn_ptr { + **fn_ptr = $fn; + } + )+ + } + }; +} + +impl<'a> DsoModule<'a> { + pub fn new>(filename: P) -> Result>, failure::Error> { + let lib = libloading::Library::new(filename)?; + + init_context_func!( + lib, + (TVMAPISetLastError, extern "C" fn(*const i8)), + ( + TVMBackendAllocWorkspace, + extern "C" fn(c_int, c_int, u64, c_int, c_int) -> *mut c_void + ), + ( + TVMBackendFreeWorkspace, + extern "C" fn(c_int, c_int, *mut c_void) -> c_int + ), + ( + TVMBackendParallelLaunch, + extern "C" fn(crate::threading::FTVMParallelLambda, *const c_void, usize) -> c_int + ), + ( + TVMBackendParallelBarrier, + extern "C" fn(usize, *const tvm_common::ffi::TVMParallelGroupEnv) + ), + ); + + // Pin the module in memory so that `ctx` pointer (below) is stable. 
+ let dso_mod = Box::pin(Self { + lib, + packed_funcs: RefCell::new(HashMap::new()), + _pin: std::marker::PhantomPinned, + }); + + unsafe { + if let Ok(ctx) = dso_mod.lib.get::<*mut *const c_void>(TVM_MODULE_CTX) { + **ctx = &dso_mod as *const _ as *const c_void; + } + } + + Ok(dso_mod) + } +} + +impl<'a> Module for DsoModule<'a> { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)> { + let name = name.as_ref(); + let func = match unsafe { + self.lib + .get::(if name.as_bytes() == TVM_MAIN { + // If __tvm_main__ is present, it contains the name of the + // actual main function. + match self + .lib + .get::<*const c_char>(TVM_MAIN) + .map(|p| CStr::from_ptr(*p)) + { + Ok(m) => m.to_bytes(), + _ => return None, + } + } else { + name.as_bytes() + }) + } { + Ok(func) => unsafe { func.into_raw() }, + Err(_) => return None, + }; + + self.packed_funcs.borrow_mut().insert( + name.to_string(), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), *func)), + ); + + self.packed_funcs.borrow().get(name).map(|f| *f) + } +} + +impl<'a> Drop for DsoModule<'a> { + fn drop(&mut self) { + self.packed_funcs + .replace(HashMap::new()) + .into_iter() + .map(|(_name, f)| unsafe { Box::from_raw(f as *const _ as *mut (dyn PackedFunc)) }) + .for_each(std::mem::drop); + } +} diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs new file mode 100644 index 000000000000..2c7c107f6b30 --- /dev/null +++ b/rust/runtime/src/module/mod.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +mod dso; +mod syslib; + +use tvm_common::{ + ffi::BackendPackedCFunc, + packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, +}; + +#[cfg(not(any(target_arch = "wasm32", target_env = "sgx")))] +pub use dso::DsoModule; +pub use syslib::SystemLibModule; + +pub trait Module { + fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; +} + +// @see `WrapPackedFunc` in `llvm_module.cc`. +fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box { + box move |args: &[TVMArgValue]| { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(tvm_common::errors::FuncCallError::get_with_context( + func_name.clone(), + )) + } + } +} diff --git a/rust/runtime/src/module.rs b/rust/runtime/src/module/syslib.rs similarity index 62% rename from rust/runtime/src/module.rs rename to rust/runtime/src/module/syslib.rs index 865338f848fa..227b8c727e8f 100644 --- a/rust/runtime/src/module.rs +++ b/rust/runtime/src/module/syslib.rs @@ -21,14 +21,9 @@ use std::{ collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, }; -use tvm_common::{ - ffi::BackendPackedCFunc, - packed_func::{PackedFunc, TVMArgValue, TVMRetValue, TVMValue}, -}; +use tvm_common::{ffi::BackendPackedCFunc, 
packed_func::PackedFunc}; -pub trait Module { - fn get_function>(&self, name: S) -> Option<&(dyn PackedFunc)>; -} +use super::Module; pub struct SystemLibModule; @@ -53,30 +48,6 @@ impl Default for SystemLibModule { } } -// @see `WrapPackedFunc` in `llvm_module.cc`. -pub(super) fn wrap_backend_packed_func( - func_name: String, - func: BackendPackedCFunc, -) -> Box { - box move |args: &[TVMArgValue]| { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(tvm_common::errors::FuncCallError::get_with_context( - func_name.clone(), - )) - } - } -} - #[no_mangle] pub extern "C" fn TVMBackendRegisterSystemLibSymbol( cname: *const c_char, @@ -85,7 +56,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol( let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert( name.to_string(), - &*Box::leak(wrap_backend_packed_func(name.to_string(), func)), + &*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)), ); return 0; } diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index 96143848f5e2..eb2f418473ed 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -42,7 +42,7 @@ use tvm_common::ffi::TVMParallelGroupEnv; #[cfg(target_env = "sgx")] use super::{TVMArgValue, TVMRetValue}; -type FTVMParallelLambda = +pub(crate) type FTVMParallelLambda = extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; /// Holds a parallel job request made by a TVM library function. 
diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs index ea3bfcb85136..3439f9c2efc7 100644 --- a/rust/runtime/tests/test_tvm_basic/build.rs +++ b/rust/runtime/tests/test_tvm_basic/build.rs @@ -19,13 +19,21 @@ extern crate ar; -use std::{env, path::Path, process::Command}; +use std::{path::PathBuf, process::Command}; use ar::Builder; use std::fs::File; fn main() { - let out_dir = env::var("OUT_DIR").unwrap(); + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest.a"); let output = Command::new(concat!( env!("CARGO_MANIFEST_DIR"), @@ -35,7 +43,7 @@ fn main() { .output() .expect("Failed to execute command"); assert!( - Path::new(&format!("{}/test.o", out_dir)).exists(), + obj_file.exists(), "Could not build tvm lib: {}", String::from_utf8(output.stderr) .unwrap() @@ -45,9 +53,9 @@ fn main() { .unwrap_or("") ); - let mut builder = Builder::new(File::create(format!("{}/libtest.a", out_dir)).unwrap()); - builder.append_path(format!("{}/test.o", out_dir)).unwrap(); + let mut builder = Builder::new(File::create(lib_file).unwrap()); + builder.append_path(obj_file).unwrap(); println!("cargo:rustc-link-lib=static=test"); - println!("cargo:rustc-link-search=native={}", out_dir); + println!("cargo:rustc-link-search=native={}", out_dir.display()); } diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs index 14bb7c20c680..a83078e5834a 100644 --- a/rust/runtime/tests/test_tvm_basic/src/main.rs +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -22,13 +22,14 @@ extern crate ndarray; extern crate tvm_runtime; use ndarray::Array; -use tvm_runtime::{DLTensor, Module, SystemLibModule}; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +mod tvm_mod { + import_module!("../lib/test.o"); 
+} fn main() { - let syslib = SystemLibModule::default(); - let add = syslib - .get_function("default_function") - .expect("main function not found"); + // try static let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); let mut c = Array::from_vec(vec![0f32; 4]); @@ -36,6 +37,14 @@ fn main() { let mut a_dl: DLTensor = (&mut a).into(); let mut b_dl: DLTensor = (&mut b).into(); let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(tvm_mod::default_function, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); + + // try runtime + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); assert!(c.all_close(&e, 1e-8f32)); } diff --git a/rust/runtime/tests/test_tvm_dso/Cargo.toml b/rust/runtime/tests/test_tvm_dso/Cargo.toml new file mode 100644 index 000000000000..afe7f26e1220 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "test-tvm-dso" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_tvm_dso/build.rs b/rust/runtime/tests/test_tvm_dso/build.rs new file mode 100644 index 000000000000..f1d9822b01a5 --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::{env, path::Path, process::Command}; + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/test.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); +} diff --git a/conda/cross-linux.cmake b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py old mode 100644 new mode 100755 similarity index 54% rename from conda/cross-linux.cmake rename to rust/runtime/tests/test_tvm_dso/src/build_test_lib.py index f84ba8e44a26..63b43a5f9bef --- a/conda/cross-linux.cmake +++ b/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,23 +16,25 @@ # specific language governing permissions and limitations # under the License. 
-# this one is important -set(CMAKE_SYSTEM_NAME Linux) -set(CMAKE_PLATFORM Linux) -#this one not so much -set(CMAKE_SYSTEM_VERSION 1) +"""Prepares a simple TVM library for testing.""" -# specify the cross compiler -set(CMAKE_C_COMPILER $ENV{CC}) +from os import path as osp +import sys -# where is the target environment -set(CMAKE_FIND_ROOT_PATH $ENV{PREFIX} $ENV{BUILD_PREFIX}/$ENV{HOST}/sysroot) +import tvm +from tvm.contrib import cc -# search for programs in the build host directories -set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) -# for libraries and headers in the target directories -set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) -set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +def main(): + n = tvm.var('n') + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + obj_file = osp.join(sys.argv[1], 'test.o') + tvm.build(s, [A, B, C], 'llvm').save(obj_file) + cc.create_shared(osp.join(sys.argv[1], 'test.so'), [obj_file]) -# god-awful hack because it seems to not run correct tests to determine this: -set(__CHAR_UNSIGNED___EXITCODE 1) +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_tvm_dso/src/main.rs b/rust/runtime/tests/test_tvm_dso/src/main.rs new file mode 100644 index 000000000000..953676cea5bb --- /dev/null +++ b/rust/runtime/tests/test_tvm_dso/src/main.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, DsoModule, Module}; + +fn main() { + tvm_runtime::TVMGetLastError(); + let module = DsoModule::new(concat!(env!("OUT_DIR"), "/test.so")).unwrap(); + let add = module + .get_function("__tvm_main__") + .expect("main function not found"); + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index fce73aabf6a7..f31f02b1eaf4 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector") TVM_REGISTER_API("arith.intset_interval") .set_body_typed(IntSet::interval); + TVM_REGISTER_API("arith.DetectLinearEquation") .set_body_typed(DetectLinearEquation); @@ -58,7 +59,6 @@ TVM_REGISTER_API("arith.DeduceBound") TVM_REGISTER_API("arith.DomainTouched") .set_body_typed(DomainTouched); - TVM_REGISTER_API("_IntervalSetGetMin") .set_body_method(&IntSet::min); @@ -71,11 +71,19 @@ TVM_REGISTER_API("_IntSetIsNothing") TVM_REGISTER_API("_IntSetIsEverything") .set_body_method(&IntSet::is_everything); +ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { + return ConstIntBound(min_value, max_value); +} + TVM_REGISTER_API("arith._make_ConstIntBound") -.set_body_typed(ConstIntBoundNode::make); +.set_body_typed(MakeConstIntBound); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { + return ModularSet(coeff, base); +} TVM_REGISTER_API("arith._make_ModularSet") -.set_body_typed(ModularSetNode::make); +.set_body_typed(MakeModularSet); TVM_REGISTER_API("arith._CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -103,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->int_set(args[0], args[1]); + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { auto& sptr = args[1].node_sptr(); @@ -116,8 +128,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { // can't use make_shared due to noexcept(false) 
decl in destructor, // see https://stackoverflow.com/a/43907314 - auto ctx = - std::shared_ptr(new ConstraintContext(self.get(), args[0])); + auto ctx = std::shared_ptr >( + new With(self.get(), args[0])); auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 6195aac1b93f..e5b003cafb87 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -130,6 +130,7 @@ REGISTER_PASS(RewriteUnsafeSelect); REGISTER_PASS(Inline); REGISTER_PASS(IRTransform); REGISTER_PASS(VectorizeLoop); +REGISTER_PASS(SkipVectorize); REGISTER_PASS(UnrollLoop); REGISTER_PASS(InjectCopyIntrin); REGISTER_PASS(ThreadSync); @@ -151,6 +152,7 @@ REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); +REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); REGISTER_PASS(VerifyMemory); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 420d6f9c1d0d..10a1c7f041c3 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,7 +31,8 @@ Analyzer::Analyzer() : const_int_bound(this), modular_set(this), rewrite_simplify(this), - canonical_simplify(this) { + canonical_simplify(this), + int_set(this) { } void Analyzer::Bind(const VarExpr& v, const Expr& expr) { @@ -54,10 +55,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) { // skip rewrite simplify } -ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { + +void ConstraintContext::EnterWithScope() { + CHECK(exit_ == nullptr); // entering the scope. - auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); - auto f1 = analyzer->modular_set.EnterConstraint(constraint); + auto f0 = analyzer_->const_int_bound.EnterConstraint(constraint_); + auto f1 = analyzer_->modular_set.EnterConstraint(constraint_); // recovery function. exit_ = [f0, f1]() { if (f1 != nullptr) f1(); @@ -65,9 +68,14 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) }; } +void ConstraintContext::ExitWithScope() { + CHECK(exit_ != nullptr); + exit_(); +} + bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { - return ptr->value > lower_bound; + return ptr->value >= lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); if (bd->min_value >= lower_bound) return true; diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 89e556c6f75f..395a371f43af 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,12 +30,12 @@ #include #include +#include "int_set.h" namespace tvm { namespace arith { using namespace ir; -using HalideIR::Internal::Interval; // a visitor to find the path to the target variable // from a expression. @@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e, BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success) return IntSet::nothing(); - Expr min = Interval::neg_inf, max = Interval::pos_inf; + Expr min = neg_inf(), max = pos_inf(); if (d.is_greater) { min = d.result; } else { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1bf1f84fb635..a50cbfb96591 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file canonical_simplify.cc * \brief Canonical form based simplification. */ @@ -763,7 +762,10 @@ Mutate_(const Mod* op, const Expr& self) { if (TryCompare(temp, cval) == kLT) { return temp; } else { - return SplitModConst(ToSplitExpr(temp), cval); + // contonue to use logic below. + a = extra; + psum = a.as(); + CHECK(psum != nullptr); } } } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index ff2fb8dbd4ac..cc54bff596be 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,8 +27,8 @@ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #include -#include #include +#include namespace tvm { namespace arith { @@ -105,12 +105,12 @@ inline Expr ComputeExpr(Expr a, Expr b) { template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_max(a, b); + return max(a, b); } template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_min(a, b); + return min(a, b); } template diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index fbf8fe7e6f89..ec50aef5c51e 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -206,6 +206,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -216,6 +217,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -307,6 +309,58 @@ inline Expr TryConstFold(Expr a) { return Expr(); } +/*! 
\brief Helper namespace for symbolic value limits */ +struct SymbolicLimits { + /*! \brief positive infinity */ + static Expr pos_inf_; + /*! \brief negative infinity */ + static Expr neg_inf_; +}; + +/*! + * \brief Opaque expression representing positive infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return positive infinity. + */ +inline Expr pos_inf() { + return SymbolicLimits::pos_inf_; +} + +/*! + * \brief Check if value is positive infinity. + * \param value The value to be checked. + * + * \return The check result. + */ +inline bool is_pos_inf(const Expr& value) { + return value.same_as(SymbolicLimits::pos_inf_); +} + +/*! + * \brief Opaque expression representing negative infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return negative infinity. + */ +inline Expr neg_inf() { + return SymbolicLimits::neg_inf_; +} + +/*! + * \brief Check if value is negative infinity. + * \param value The value to be checked. + * + * \return The check result. 
+ */ +inline bool is_neg_inf(const Expr& value) { + return value.same_as(SymbolicLimits::neg_inf_); +} + } // namespace arith } // namespace tvm #endif // TVM_ARITHMETIC_CONST_FOLD_H_ diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index bfd06c8ba255..ed8faba3509b 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -34,12 +34,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound ConstIntBoundNode::make( +ConstIntBound::ConstIntBound( int64_t min_value, int64_t max_value) { auto node = make_node(); node->min_value = min_value; node->max_value = max_value; - return ConstIntBound(node); + node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -190,7 +190,7 @@ class ConstIntBoundAnalyzer::Impl : std::min(a.max_value, b_max_cap)); } else { return MakeBound(std::max(a.min_value, -b_max_cap), - std::min(a.max_value, b_max_cap)); + std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); } } else { CHECK(!b.is_const(0)) << "mod by zero"; @@ -289,8 +289,8 @@ class ConstIntBoundAnalyzer::Impl : std::vector additional_info_; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. - static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf; - static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf; + static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; + static const constexpr int64_t kPosInf = ConstIntBound::kPosInf; static_assert(-kNegInf == kPosInf, "invariant of inf"); // internal helper functions /*! 
@@ -462,7 +462,7 @@ class ConstIntBoundAnalyzer::Impl : ConstIntBound ConstIntBoundAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ConstIntBoundNode::make(ret.min_value, ret.max_value); + return ConstIntBound(ret.min_value, ret.max_value); } void ConstIntBoundAnalyzer::Update(const Var& var, diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 2fe21fef7e21..e584c8b1ce33 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,8 +19,8 @@ /*! * Copyright (c) 2017 by Contributors - * \file bound_deducer.cc - * \brief Utility to deduce bound of expression + * \file detect_linear_equation.cc + * \brief Utility to detect patterns in the expression. */ #include #include diff --git a/src/arithmetic/int_op_overflow.h b/src/arithmetic/int_op_overflow.h index 87f4f059e858..b78f21cb1dba 100644 --- a/src/arithmetic/int_op_overflow.h +++ b/src/arithmetic/int_op_overflow.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index abbb7cd9744e..75a4aaf83ab6 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,201 +18,55 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file int_set.cc * \brief The integer set functions */ #include -#include -#include #include -#include +#include +#include +#include #include -#include "compute_expr.h" -#include "int_set_internal.h" +#include "int_set.h" +#include "pattern_match.h" namespace tvm { namespace arith { -using HalideIR::Internal::Interval; -using namespace ir; - -inline IntSet IntSet::cover_interval() const { - if ((*this).as()) return *this; - const StrideSet* s = (*this).as(); - if (s) { - CHECK_NE(s->extents.size(), 0U); - Expr max = s->base.max; - for (size_t i = 0; i < s->extents.size(); ++i) { - max = max + s->extents[i] * s->strides[i] - s->strides[i]; - } - return IntervalSet::make(s->base.min, Simplify(max)); - } - LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval"; - return IntSet::everything(); -} - -Range IntSet::cover_range(Range max_range) const { - IntSet temp; - const IntervalSet* s_int = (*this).as(); - if (s_int == nullptr) { - temp = this->cover_interval(); - s_int = temp.as(); - } - if 
(s_int->i.is_bounded()) { - return Range::make_by_min_extent( - s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); - } - return max_range; -} - -Expr IntSet::min() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.min; -} - -Expr IntSet::max() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.max; -} - -bool IntSet::is_nothing() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_empty()); -} - -bool IntSet::is_everything() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_everything()); -} +Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle()); +Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle()); -bool IntSet::is_single_point() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_single_point()); +IntervalSet::IntervalSet(Expr min_value, Expr max_value) { + auto node = make_node(); + node->min_value = std::move(min_value); + node->max_value = std::move(max_value); + node_ = std::move(node); } -bool IntSet::can_prove_positive() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_positive_const(ir::Simplify(s_int->i.min))); +IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { + return IntervalSet(min_value, max_value); } -bool IntSet::can_prove_negative() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); -} +TVM_REGISTER_API("arith._make_IntervalSet") +.set_body_typed(MakeIntervalSet); -bool IntSet::can_prove_non_positive() const { - if (const IntervalSet* s_int = (*this).as()) { - auto max = ir::Simplify(s_int->i.max); - return is_zero(max) || is_negative_const(max); - } - return false; -} -bool IntSet::can_prove_non_negative() const { - if (const IntervalSet* s_int = (*this).as()) { - // Any reason why we should or should not use can_prove() to implement - // these functions? 
- auto min = ir::Simplify(s_int->i.min); - return is_zero(min) || is_positive_const(min); - } - return false; -} - - -SignType IntSet::sign_type() const { - if (can_prove_positive()) { - return kPositive; - } else if (can_prove_negative()) { - return kNegative; - } else if (is_single_point() && is_zero(point_value())) { - return kZero; +IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = min(a->max_value, b->max_value); + Expr min_value = max(a->min_value, b->min_value); + if ((max_value.type().is_int() || max_value.type().is_uint()) && + (min_value.type().is_int() || min_value.type().is_uint()) && + analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { + return IntervalSet::Empty(); } else { - return kUnknown; - } -} -Expr IntSet::point_value() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int && s_int->i.is_single_point()); - return s_int->i.min; -} - -IntSet IntSet::nothing() { - return IntervalSet::make(Interval::nothing()); -} - -IntSet IntSet::everything() { - return IntervalSet::make(Interval::everything()); -} - -IntSet IntSet::single_point(Expr x) { - return IntervalSet::make(Interval::single_point(x)); -} - -IntSet IntSet::range(Range r) { - // must make sure it can be matched back by MatchRange. - if (is_one(r->extent)) { - return IntSet::single_point(r->min); - } - if (is_positive_const(r->extent) && is_const(r->min)) { - return IntervalSet::make( - r->min, ComputeExpr(ComputeExpr(r->extent, r->min), 1)); - } - return IntervalSet::make(r->min, (r->extent + r->min) - 1); -} - -IntSet IntSet::interval(Expr min, Expr max) { - if (min.same_as(max)) { - return IntSet::single_point(min); - } - return IntervalSet::make(min, max); -} - -inline bool prove_equal(Expr lhs, Expr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); -} - -// Check if a is created from b. 
-bool IntSet::match_range(const Range& b) const { - const IntSet& a = *this; - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return prove_equal(i.min, b->min) && - prove_equal(i.max, ComputeExpr(ComputeExpr(b->extent, b->min), 1)); -} - -inline bool MatchPoint(const IntSet& a, - const Expr& b) { - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return i.is_single_point() && i.min.same_as(b); -} - -IntSet Union(const Array& sets) { - if (sets.size() == 0) return IntSet::nothing(); - if (sets.size() == 1) return sets[0]; - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - IntSet s = sets[i].cover_interval(); - const Interval& y = s.as()->i; - x.include(y); + return IntervalSet(min_value, max_value); } - x.max = ir::Simplify(x.max); - x.min = ir::Simplify(x.min); - return IntervalSet::make(x); } -IntSet Intersect(const Array& sets) { - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - Interval y = sets[i].cover_interval().as()->i; - x = Interval::make_intersection(x, y); - } - return IntervalSet::make(x); +IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = max(a->max_value, b->max_value); + Expr min_value = min(a->min_value, b->min_value); + return IntervalSet(min_value, max_value); } // type traits @@ -227,407 +81,623 @@ struct is_logical_op { static const bool value = true; \ }; -// interval related. 
-template -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key; - return IntSet::everything(); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); + +/*! + * \brief Combine two interval set under arithmetic operations. + * \note this can possibly relax the set. + */ +template +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + Expr res = TryConstFold(a->min_value, b->min_value); + if (!res.defined()) res = Op::make(a->min_value, b->min_value); + return IntervalSet::SinglePoint(res); + } + if (is_logical_op::value) { + return IntervalSet(make_const(a->min_value.type(), 0), + make_const(a->min_value.type(), 1)); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsEverything()) return a; + if (b->IsEverything()) return b; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_lower_bound()) { - r.min = ComputeExpr(a.min, b.min); - } - if (a.has_upper_bound() && b.has_upper_bound()) { - r.max = ComputeExpr(a.max, b.max); - } - return IntervalSet::make(r); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value + b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + 
a->HasLowerBound() && b->HasLowerBound() ? + a->min_value + b->min_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasUpperBound() ? + a->max_value + b->max_value : pos_inf(); + return IntervalSet(min_value, max_value); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value - b->min_value); } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_upper_bound()) { - r.min = ComputeExpr(a.min, b.max); - } - if (a.has_upper_bound() && b.has_lower_bound()) { - r.max = ComputeExpr(a.max, b.min); - } - return IntervalSet::make(r); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + a->HasLowerBound() && b->HasUpperBound() ? + a->min_value - b->max_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasLowerBound() ? + a->max_value - b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); } + template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - if (a.is_single_point() && !b.is_single_point()) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value * b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsSinglePoint()) { std::swap(a, b); } - if (b.is_single_point()) { - if (is_zero(b.min)) return IntSet::single_point(0); - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? 
ComputeExpr(a.max, b.min) : a.max; - // no relaxation is needed in here due to set is inclusive - // TODO(tqchen): consider convert to StrideSet. - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) return b; + if (is_one(b->min_value)) return a; + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value * b->min_value; + Expr e2 = a->max_value * b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Mul"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mul"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval
(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr
(a.min, b.min)); - } - if (b.is_single_point()) { - if (is_zero(b.min)) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value / b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) { LOG(FATAL) << "Divide by zero in CombineInterval Div"; } - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr
(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? ComputeExpr
(a.max, b.min) : a.max; + if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value / b->min_value; + Expr e2 = a->max_value / b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Div"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Div"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value % b->min_value); } - if (b.is_single_point()) { - Expr divisor = b.min; + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + + if 
(b->IsSinglePoint()) { + const Expr& divisor = b->min_value; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } - return IntervalSet::make(make_zero(divisor.type()), divisor - 1); + // We need to add more bound constraints throughout the code. + // The logic below assumes a is non-negative, which usually + // is the case of our application. + // TODO(tqchen): add bound constraints for a. + if (analyzer->CanProveGreaterEqual(divisor, 0)) { + return IntervalSet(make_zero(divisor.type()), divisor - 1); + } else { + Expr bound = abs(divisor) - 1; + return IntervalSet(-bound, bound); + } } - - LOG(WARNING) << "Return Everything in CombineInterval Mod"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mod"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } - return IntervalSet::make(Interval::make_max(a.min, b.min), - Interval::make_max(a.max, b.max)); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(max(a->min_value, b->min_value), + max(a->max_value, b->max_value)); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } - return IntervalSet::make(Interval::make_min(a.min, b.min), - Interval::make_min(a.max, b.max)); -} - -template -inline IntSet CombineInterval_(IntSet a, IntSet b) { 
- return CombineInterval( - a.as()->i, b.as()->i); -} - -// stride related -inline IntSet AsStrideSet(IntSet a) { - if (a.as()) return a; - const IntervalSet* s = a.as(); - CHECK(s->i.is_bounded()); - NodePtr n = make_node(); - n->base = s->i; - return IntSet(n); -} -template -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineInterval_(a.cover_interval(), b.cover_interval()); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(min(a->min_value, b->min_value), + min(a->max_value, b->max_value)); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && is_zero(a_int->i.min)) return b; - if (b_int && is_zero(b_int->i.min)) return a; - a = AsStrideSet(a); - b = AsStrideSet(b); - const StrideSet* a_stride = a.as(); - const StrideSet* b_stride = b.as(); - auto n = make_node(*a_stride); - for (size_t i = 0; i < b_stride->extents.size(); ++i) { - n->extents.push_back(b_stride->extents[i]); - n->strides.push_back(b_stride->strides[i]); - } - n->base = CombineInterval( - a_stride->base, b_stride->base).as()->i; - return IntSet(n); -} - -inline IntSet NegateSet(IntSet a) { - const IntervalSet* a_int = a.as(); - if (a_int) { - if (a_int->i.is_single_point()) { - return IntSet::single_point(-a_int->i.min); - } else { - Interval r = Interval::everything(); - if (a_int->i.has_upper_bound()) { - r.min = -(a_int->i.max); - } - if (a_int->i.has_lower_bound()) { - r.max = -(a_int->i.min); - } - return IntervalSet::make(r); - } - } else { - return NegateSet(a.cover_interval()); +// internal helper function to get an interval set +IntervalSet ToIntervalSet(IntSet set) { + if (auto* node = set.as()) { + return GetRef(node); } + DLOG(INFO) << "cannot resolve int set " << set; + return IntervalSet::Everything(); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineSets(a, NegateSet(b)); -} - -TVM_DECLARE_LOGICAL_OP(And); 
-TVM_DECLARE_LOGICAL_OP(Or); -TVM_DECLARE_LOGICAL_OP(EQ); -TVM_DECLARE_LOGICAL_OP(NE); -TVM_DECLARE_LOGICAL_OP(GE); -TVM_DECLARE_LOGICAL_OP(GT); -TVM_DECLARE_LOGICAL_OP(LE); -TVM_DECLARE_LOGICAL_OP(LT); -TVM_DECLARE_LOGICAL_OP(Not); +using namespace ir; -// generic combine operations of two sets -template -inline IntSet Combine(const IntSet& a, const IntSet &b) { - if (is_logical_op::value) { - return IntervalSet::make(0, 1); +// Simplified version of int set evaluator that operates on IntervalSet +// We might use better set analysis in the future to replace the intervalset. +class IntervalSetEvaluator : + public ExprFunctor { + public: + IntervalSetEvaluator(Analyzer* analyzer, + const Map& dom_map, + bool eval_vec = false) + : analyzer_(analyzer), + dom_map_(dom_map), + eval_vec_(eval_vec) { } - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && a_int->i.is_everything()) return a; - if (b_int && b_int->i.is_everything()) return b; - if (a_int && b_int) { - return CombineInterval(a_int->i, b_int->i); + + IntervalSet Eval(const Expr& val) { + return this->VisitExpr(val); } - if (a_int && !(a_int->i.is_bounded())) { - return CombineInterval_(a, b.cover_interval()); + + IntervalSet VisitExpr_(const IntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - if (b_int && !(b_int->i.is_bounded())) { - return CombineInterval_(a.cover_interval(), b); + + IntervalSet VisitExpr_(const UIntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - return CombineSets(a, b); -} -class IntSetEvaluator : - public ExprFunctor { - public: - explicit IntSetEvaluator( - const std::unordered_map& dom_map, - bool eval_vec = false) - : dom_map_(dom_map), eval_vec_(eval_vec) {} - // Evaluate. 
- IntSet Eval(const Expr& e) { - return this->VisitExpr(e, e); - } - IntSet VisitExpr_(const IntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const UIntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const Variable* op, const Expr& e) final { - auto it = dom_map_.find(op); + IntervalSet VisitExpr_(const Variable* op) final { + Var var = GetRef(op); + auto it = dom_map_.find(var); if (it != dom_map_.end()) { - return it->second; + return ToIntervalSet((*it).second); } else { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(var); } } - IntSet VisitExpr_(const Add* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Add* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Sub* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Sub* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mul* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mul* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Div* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Div* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mod* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mod* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Min* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Min* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Max* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Max* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const EQ* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const EQ* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const NE* op, 
const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const NE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const And* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const And* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Or* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Or* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Ramp* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Ramp* op) final { CHECK(eval_vec_); - IntSet base = Eval(op->base); - int vstride; - if (GetConstInt(op->stride, &vstride)) { + IntervalSet base = Eval(op->base); + PVar stride; + if (stride.Match(op->stride)) { Type t = op->base.type(); - if (vstride > 0) { + int64_t vstride = stride.Eval()->value; + if (vstride> 0) { return Combine( + analyzer_, base, - IntSet::interval(make_zero(t), - make_const(t, vstride * op->lanes -1))); + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { return Combine( + analyzer_, base, - IntSet::interval(make_const(t, vstride * op->lanes + 1), - make_zero(t))); + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } - LOG(WARNING) << "cannot evaluate set on expression " << e; - return IntSet::everything(); + 
DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + return IntervalSet::Everything(); } - IntSet VisitExpr_(const Broadcast* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Broadcast* op) final { CHECK(eval_vec_); - return Eval(op->value); + return VisitExpr(op->value); } - IntSet VisitExpr_(const Select* op, const Expr& e) final { - IntSet true_set = this->Eval(op->true_value); - IntSet false_set = this->Eval(op->false_value); - return Union({false_set, true_set}); + + IntervalSet VisitExpr_(const Select* op) final { + IntervalSet true_set = this->Eval(op->true_value); + IntervalSet false_set = this->Eval(op->false_value); + return Union(analyzer_, false_set, true_set); } - IntSet VisitExprDefault_(const Node* op, const Expr& e) final { - LOG(WARNING) << "cannot evaluate set type " << e->type_key(); - return IntSet::everything(); + + IntervalSet VisitExprDefault_(const Node* op) final { + DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + return IntervalSet::Everything(); } private: + // whether set is exactly single point that equals value. 
+ bool MatchPoint(const IntervalSet& set, + const Expr& value) const { + return set->min_value.same_as(value) && set->max_value.same_as(value); + } + template - inline IntSet Binary(const T* op, const Expr& e) { - IntSet a = this->Eval(op->a); - IntSet b = this->Eval(op->b); + inline IntervalSet VisitBinaryExpr_(const T* op) { + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(a, b); + return Combine(analyzer_, a, b); } - const std::unordered_map& dom_map_; + Analyzer* analyzer_; + const Map& dom_map_; bool eval_vec_{false}; }; +class IntSetAnalyzer::Impl { + public: + explicit Impl(Analyzer* analyzer) + : analyzer_(analyzer) { + } + + IntSet Eval(const Expr& expr, const Map& dom_map) const { + return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); + } + + private: + Analyzer* analyzer_; +}; + +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +IntSetAnalyzer::~IntSetAnalyzer() { + delete impl_; +} + +IntSet IntSetAnalyzer::operator()(const Expr& expr, + const Map& dom_map) { + return impl_->Eval(expr, dom_map); +} + +// Quickly adapt to IntSet interface +// TODO(tqchen): revisit IntSet interface as well. 
+Range IntSet::cover_range(Range max_range) const { + IntSet temp; + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int != nullptr); + if (s_int->HasUpperBound() && s_int->HasLowerBound()) { + return Range::make_by_min_extent( + s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + } + return max_range; +} + +Expr IntSet::min() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->min_value; +} + +Expr IntSet::max() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->max_value; +} + +bool IntSet::is_nothing() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEmpty()); +} + +bool IntSet::is_everything() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEverything()); +} + +bool IntSet::is_single_point() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsSinglePoint()); +} + +bool IntSet::can_prove_positive() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_positive_const(ir::Simplify(s_int->min_value))); +} + +bool IntSet::can_prove_negative() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_negative_const(ir::Simplify(s_int->max_value))); +} + +bool IntSet::can_prove_non_positive() const { + if (const auto* s_int = (*this).as()) { + auto max = ir::Simplify(s_int->max_value); + return is_zero(max) || is_negative_const(max); + } + return false; +} + +bool IntSet::can_prove_non_negative() const { + if (const IntervalSetNode* s_int = (*this).as()) { + auto min = ir::Simplify(s_int->min_value); + return is_zero(min) || is_positive_const(min); + } + return false; +} + +SignType IntSet::sign_type() const { + if (can_prove_positive()) { + return kPositive; + } else if (can_prove_negative()) { + return kNegative; + } else if (is_single_point() && is_zero(point_value())) { + return kZero; + } else { + return kUnknown; + } +} +Expr 
IntSet::point_value() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int && s_int->IsSinglePoint()); + return s_int->min_value; +} + +IntSet IntSet::nothing() { + return IntervalSet::Empty(); +} + +IntSet IntSet::everything() { + return IntervalSet::Everything(); +} + +IntSet IntSet::single_point(Expr x) { + return IntervalSet::SinglePoint(x); +} + +IntSet IntSet::interval(Expr min, Expr max) { + if (min.same_as(max)) { + return IntSet::single_point(min); + } + return IntervalSet(min, max); +} + +// Range related code +inline bool ProveEqual(Expr lhs, Expr rhs) { + return is_zero(ir::Simplify(lhs - rhs)); +} + +IntSet IntSet::range(Range r) { + // must make sure it can be matched back by MatchRange. + if (is_one(r->extent)) { + return IntSet::single_point(r->min); + } + return IntervalSet(r->min, r->extent + r->min - 1); +} + +bool IntSet::match_range(const Range& b) const { + const IntSet& a = *this; + const IntervalSetNode* a_int = a.as(); + if (!a_int) return false; + return ProveEqual(a_int->min_value, b->min) && + ProveEqual(a_int->max_value, b->extent + b->min - 1); +} + +IntSet Union(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Union(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +IntSet Intersect(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Intersect(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +Map ConvertDomMap(const Map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(kv.first->var, kv.second); + } + return dmap; +} + +Map 
ConvertDomMap( + const std::unordered_map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(GetRef(kv.first), kv.second); + } + return dmap; +} + IntSet EvalSet(Expr e, - const std::unordered_map& dom_map) { - return IntSetEvaluator(dom_map, false).Eval(e); + const Map& dom_map) { + Analyzer ana; + return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } IntSet IntSet::vector(Expr x) { - std::unordered_map dmap; - return IntSetEvaluator(dmap, true).Eval(x); + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } IntSet EvalSet(Expr e, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(e, dmap); + return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, +IntSet EvalSet(Expr e, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - IntSet min_set = m.Eval(r->min).cover_interval(); + return EvalSet(e, ConvertDomMap(dom_map)); +} + +IntSet EvalSet(Range r, + const Map& dom_map) { + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + IntervalSet min_set = m.Eval(r->min); // Simplifying first can give tighter bounds if r->min and r->extent share variables - Expr sum = ComputeExpr(ComputeExpr(r->min, r->extent), 1); - IntSet max_set = m.Eval(Simplify(sum)).cover_interval(); - const Interval& ni = min_set.as()->i; - const Interval& xi = max_set.as()->i; - if (!ni.has_lower_bound()) return IntSet::everything(); - if (!xi.has_upper_bound()) return IntSet::everything(); - return IntervalSet::make(ni.min, xi.max); + Expr sum = r->min + r->extent - 1; + IntervalSet max_set = m.Eval(Simplify(sum)); + if (!min_set->HasLowerBound()) return IntSet::everything(); + if (!max_set->HasUpperBound()) return IntSet::everything(); + return IntervalSet(min_set->min_value, max_set->max_value); } -IntSet EvalSet(IntSet s, +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - s = 
s.cover_interval(); - const IntervalSet* s_int = s.as(); - Expr vmax = s_int->i.has_upper_bound() ? - m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max; - Expr vmin = s_int->i.has_lower_bound() ? - m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min; - return IntervalSet::make(vmin, vmax); + return EvalSet(r, ConvertDomMap(dom_map)); } -class SubExprIntSetEvaluator : public IntSetEvaluator { +IntSet EvalSet(IntSet s, + const std::unordered_map& dom_map) { + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + IntervalSetEvaluator m(&ana, dmap); + const IntervalSetNode* s_int = s.as(); + Expr vmax = s_int->HasUpperBound() ? + m.Eval(s_int->max_value).max() : s_int->max_value; + Expr vmin = s_int->HasLowerBound() ? + m.Eval(s_int->min_value).min() : s_int->min_value; + return IntervalSet(vmin, vmax); +} + +class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntSetEvaluator( - const std::unordered_map& dom_map) - : IntSetEvaluator(dom_map) {} + explicit SubExprIntervalSetEvaluator( + Analyzer* analyzer, + const Map& dom_map) + : IntervalSetEvaluator(analyzer, dom_map) {} - IntSet VisitExpr(const Expr& n, const Expr& e) final { - IntSet ret = IntSetEvaluator::VisitExpr(n, e); + IntervalSet VisitExpr(const Expr& n) final { + IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); expr_map[n] = ret; return ret; } @@ -635,28 +705,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr(Expr e, +ExprIntSetMap EvalSetForEachSubExpr( + Expr e, const std::unordered_map& dom_map) { - SubExprIntSetEvaluator m(dom_map); + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + SubExprIntervalSetEvaluator m(&ana, dmap); m.Eval(e); return m.expr_map; } IntSet EvalSet(Range r, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(r, dmap); + return EvalSet(r, 
ConvertDomMap(dom_map)); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const IntervalSet *op, IRPrinter *p) { - p->stream << "interval-set" - << "[" << op->i.min << ", " - << op->i.max << ']'; +.set_dispatch([](const IntervalSetNode *op, IRPrinter *p) { + p->stream << "IntervalSet" + << "[" << op->min_value << ", " + << op->max_value << ']'; }); - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h new file mode 100644 index 000000000000..bf7fec24f78a --- /dev/null +++ b/src/arithmetic/int_set.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file int_set.h + * \brief Internal data structure for integer set. + */ +#ifndef TVM_ARITHMETIC_INT_SET_H_ +#define TVM_ARITHMETIC_INT_SET_H_ + +#include +#include +#include +#include "const_fold.h" + +namespace tvm { +namespace arith { + +/*! + * \brief Symbolic interval set. + * + * \note We intentionally keep the internal of IntSet private, + as we might change it later. + */ +class IntervalSetNode : public IntSetNode { + public: + /*! \brief Minimum value in the interval. */ + Expr min_value; + /*! \brief Maximum value in the interval. 
*/ + Expr max_value; + + // visitor overload. + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("min_value", &min_value); + v->Visit("max_value", &max_value); + } + + /*! \return Whether the interval has upper bound. */ + bool HasUpperBound() const { + return !is_pos_inf(max_value) && !IsEmpty(); + } + /*! \return Whether the interval has lower bound. */ + bool HasLowerBound() const { + return !is_neg_inf(min_value) && !IsEmpty(); + } + /*! \return Whether the interval is a single point. */ + bool IsSinglePoint() const { + return min_value.same_as(max_value); + } + /*! \return whether interval represent nothing */ + bool IsEmpty() const { + // during computations, either extreme could occur. + return is_pos_inf(min_value) || is_neg_inf(max_value); + } + /*! \return whether interval represent everything */ + bool IsEverything() const { + return is_neg_inf(min_value) && is_pos_inf(max_value); + } + + static constexpr const char* _type_key = "arith.IntervalSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); +}; + +/*! + * \brief Interval set used for symbolic integer analysis. + * \sa IntervalSetNode + */ +class IntervalSet : public IntSet { + public: + /*! + * \brief Make a new instance of interval set. + * \param min_value The minimum value in the interval. + * \param max_value The maximum value in the interval. + * \return The created set. + */ + TVM_DLL IntervalSet(Expr min_value, Expr max_value); + + /*! + * \brief Create an IntervalSet that represents a single point. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet SinglePoint(Expr value) { + return IntervalSet(value, value); + } + /*! + * \brief Create an IntervalSet that represents everything. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet Everything() { + return IntervalSet(neg_inf(), pos_inf()); + } + /*! + * \brief Create an empty eet. + * \return The result set. 
+ */ + static IntervalSet Empty() { + return IntervalSet(pos_inf(), neg_inf()); + } + + TVM_DEFINE_NODE_REF_COW(IntervalSetNode); + TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); +}; + +/*! + * \brief Create union of two IntervalSets. + * \param analyzer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); + +/*! + * \brief Create insersection of two IntervalSets. + * \param analzyer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITHMETIC_INT_SET_H_ diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h deleted file mode 100644 index 8b675cfbffda..000000000000 --- a/src/arithmetic/int_set_internal.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2017 by Contributors - * \file int_set_internal.h - * \brief Implementations of integer set - */ -#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_ -#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -using HalideIR::Internal::Interval; - -/*! \brief Set of continuous interval */ -struct IntervalSet : public IntSetNode { - /*! \brief the internal interval*/ - Interval i; - - static IntSet make(Interval i) { - NodePtr n = - make_node(); - n->i = i; - return IntSet(n); - } - static IntSet make(Expr min, Expr max) { - NodePtr n = - make_node(); - n->i.min = min; - n->i.max = max; - return IntSet(n); - } - - static constexpr const char* _type_key = "IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode); -}; - -/*! - * \brief set represented by strided integers - * Reserved for cases where strided access is supported. - */ -struct StrideSet : public IntSetNode { - /*! \brief the base inetrval */ - Interval base; - /*! \brief additional extents in positive number */ - Array extents; - /*! \brief additional strides in positive number */ - Array strides; - - static constexpr const char* _type_key = "StrideSet"; - TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); -}; - -} // namespace arith -} // namespace tvm - -#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_ diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 7701e04844fa..b3e943fc7631 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include "pattern_match.h" namespace tvm { @@ -35,11 +37,12 @@ using namespace ir; TVM_REGISTER_NODE_TYPE(ModularSetNode); -ModularSet ModularSetNode::make(int64_t coeff, int64_t base) { +ModularSet::ModularSet(int64_t coeff, int64_t base) { auto node = make_node(); node->coeff = coeff; node->base = base; - return ModularSet(node); + // finish construction. 
+ node_ = std::move(node); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -366,13 +369,13 @@ class ModularSetAnalyzer::Impl : * \return Bound that represent everything dtype can represent. */ static Entry Nothing() { - return Entry(0, 1); + return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { Entry ret = impl_->VisitExpr(expr); - return ModularSetNode::make(ret.coeff, ret.base); + return ModularSet(ret.coeff, ret.base); } void ModularSetAnalyzer::Update(const Var& var, diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 58d2b83a223a..ea6530631880 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -80,12 +80,6 @@ TryCompare(const Expr& x, int64_t val) { return kLT; } } - if (val == 0) { - ModularSet dmod = parent_->modular_set(diff); - if (dmod->base != 0) { - return kNE; - } - } ConstIntBound dbound = parent_->const_int_bound(diff); if (dbound->min_value > val) { return kGT; @@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) { if (dbound->max_value <= val) { return kLE; } + if (val == 0) { + ModularSet dmod = parent_->modular_set(diff); + if (dmod->base != 0) { + return kNE; + } + } return kUnknown; } @@ -284,11 +284,39 @@ Mutate_(const Sub* op, const Expr& self) { CanProveEqual(((b1 - s2) - (b2 - s1)).Eval(), 0)); // modular-div simplification - // Always pre-condition on positive integer domain + // Note that c*(x/c) + x % c == x is true for every x and c != 0 even for truncated division TVM_TRY_REWRITE_IF(x - (x / c1) * c1, x % c1, - CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + c1.Eval()->value != 0); TVM_TRY_REWRITE_IF((x / c1) * c1 - x, 0 - (x % c1), - CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - ((x + y) / c1) * c1, (x + y) % c1 - y, + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(((x + y) / c1) * c1 - x, y - ((x + y) % c1), + c1.Eval()->value != 0); + 
TVM_TRY_REWRITE_IF(x - ((x - y) / c1) * c1, (x - y) % c1 + y, + c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(((x - y) / c1) * c1 - x, 0 - (x - y) % c1 - y, + c1.Eval()->value != 0); + + TVM_TRY_REWRITE_IF(x * c2 - (x / c1) * c3, (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x / c1) * c3 - x * c2, 0 - (x % c1) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x + y) / c1) * c3, ((x + y) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x + y) / c1) * c3 - x * c2, (y - ((x + y) % c1)) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(x * c2 - ((x - y) / c1) * c3, ((x - y) % c1 + y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, (0 - (x - y) % c1 - y) * c2, + c1.Eval()->value != 0 && + c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && @@ -348,6 +376,7 @@ Mutate_(const Mul* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); TVM_TRY_RECURSIVE_REWRITE_IF( (x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); @@ -396,6 +425,16 @@ Mutate_(const Div* op, const Expr& self) { // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. + // TryConstFold doesn't work for negative cases because it is also used by legacy + // parts of tvm which still assume euclidean div. In this simplifier we assume that the division + // is truncated, so perform const folding again. 
+ // NOTE: trunc div required + if ((c1 / c2).Match(ret)) { + int64_t c1val = c1.Eval()->value; + int64_t c2val = c2.Eval()->value; + return make_const(op->type, c1val / c2val); + } + // while it is always true for trunc div // restrict to common case(positive div) TVM_TRY_REWRITE_IF((x / c1) / c2, x / (c1 * c2), @@ -595,10 +634,12 @@ Mutate_(const Mod* op, const Expr& self) { TVM_TRY_REWRITE_IF((x * c1 + y) % c2, y % c2, c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF((x + c1) % c2, x % c2, c2.Eval()->value > 0 && + c1.Eval()->value >= 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0)); @@ -606,7 +647,13 @@ Mutate_(const Mod* op, const Expr& self) { c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual((y * c1).Eval(), 0)); + + // canonicalization: x % c == x % (-c) for truncated division + // NOTE: trunc div required + TVM_TRY_RECURSIVE_REWRITE_IF(x % c1, + x % PConst(make_const(op->type, -c1.Eval()->value)), + c1.Eval()->value < 0); // try modular analysis if ((x % c1).Match(ret)) { @@ -766,7 +813,9 @@ Mutate_(const Min* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE(min(c1 - x, c2), c1 - max(x, c2 - c1)); + TVM_TRY_RECURSIVE_REWRITE_IF( + min(c1 - x, c2), c1 - max(x, c1 - c2), + c2.Eval()->value != 0); } // condition rules. @@ -914,7 +963,8 @@ Mutate_(const Max* op, const Expr& self) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE(max(c1 - x, c2), c1 - min(x, c2 - c1)); + TVM_TRY_RECURSIVE_REWRITE_IF( + max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. 
@@ -1025,20 +1075,53 @@ Mutate_(const LT* op, const Expr& self) { TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); - // require c1 > 0 to work for any div mode TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, + c1.Eval()->value <= 0 && + c2.Eval()->value > 0); + // NOTE: trunc div required (euclidean is ok too, floored is not) + TVM_TRY_REWRITE_IF(x * c2 < c1, (c1 - 1) / c2 - 1 < x, c1.Eval()->value > 0 && + c2.Eval()->value < 0); + // NOTE: trunc div required (floored is ok too, euclidean is not) + TVM_TRY_REWRITE_IF(x * c2 < c1, c1 / c2 < x, + c1.Eval()->value <= 0 && + c2.Eval()->value < 0); + + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(c1 < x * c2, (c1 + 1) / c2 - 1 < x, + c1.Eval()->value < 0 && c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < x * c2, c1 / c2 < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); + // NOTE: trunc div required (floored is ok too, euclidean is not) + TVM_TRY_REWRITE_IF(c1 < x * c2, x < (c1 + 1) / c2 + 1, + c1.Eval()->value < 0 && + c2.Eval()->value < 0); + // NOTE: trunc div required (euclidean is ok too, floored is not) + TVM_TRY_REWRITE_IF(c1 < x * c2, x < c1 / c2, + c1.Eval()->value >= 0 && + c2.Eval()->value < 0); + + TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, + c1.Eval()->value > 0 && + c2.Eval()->value > 0); + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * (c2 - 1) + 1, + c1.Eval()->value > 0 && + c2.Eval()->value <= 0); + TVM_TRY_REWRITE_IF(c1 < x / c2, (c1 + 1) * c2 - 1 < x, c1.Eval()->value >= 0 && c2.Eval()->value > 0); + // NOTE: trunc div required + TVM_TRY_REWRITE_IF(c1 < x / c2, c1 * c2 < x, + c1.Eval()->value < 0 && + c2.Eval()->value > 0); // division related simplificationx // invariance for any div mod: x - (x / c1) * c1 == x % c1 @@ -1200,11 +1283,11 @@ Mutate_(const Select* op, const 
Expr& self) { Expr cond = Mutate(op->condition); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->true_value); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->false_value); } if (is_zero(cond)) { @@ -1237,11 +1320,11 @@ Mutate_(const Call* op, const Expr& self) { Expr cond = Mutate(op->args[0]); Expr true_value, false_value; { - ConstraintContext constraint(parent_, cond); + With constraint(parent_, cond); true_value = Mutate(op->args[1]); } { - ConstraintContext constraint(parent_, Mutate(Not::make(cond))); + With constraint(parent_, Mutate(Not::make(cond))); false_value = Mutate(op->args[2]); } if (is_zero(cond)) { diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index c793214b92f4..403187eb39fd 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator { Expr condition = this->Mutate(op->condition); Stmt then_case, else_case; { - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); then_case = this->Mutate(op->then_case); } if (op->else_case.defined()) { - ConstraintContext ctx(&analyzer_, Mutate(Not::make(condition))); + With ctx(&analyzer_, Mutate(Not::make(condition))); else_case = this->Mutate(op->else_case); } if (is_one(condition)) return then_case; @@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator { Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { Expr condition = this->Mutate(op->condition); Expr message = this->Mutate(op->message); - ConstraintContext ctx(&analyzer_, condition); + With ctx(&analyzer_, condition); Stmt body = this->Mutate(op->body); if (condition.same_as(op->condition) && diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 57e300fafec2..0a488f38457b 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Compile executable modules. 
* \file build_module.cc */ @@ -84,7 +83,7 @@ Target CreateTarget(const std::string& target_name, t->device_type = kDLGPU; t->keys_array.push_back(ir::StringImm::make("cuda")); t->keys_array.push_back(ir::StringImm::make("gpu")); - t->max_num_threads = 512; + t->max_num_threads = 1024; t->thread_warp_size = 32; } else if (target_name == "rocm" || target_name == "opencl") { // For now assume rocm schedule for opencl @@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate") TVM_REGISTER_API("_TargetFromString") .set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; - - *ret = Target::create(target_str); + *ret = Target::Create(target_str); }); std::vector TargetNode::keys() const { @@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) { return ""; } -Target Target::create(const std::string& target_str) { +Target Target::Create(const std::string& target_str) { if (target_str.length() == 0) { LOG(ERROR) << "target_str must not be empty"; } @@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) { struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ std::stack context_stack; - - TVMTargetThreadLocalEntry() { - } }; /*! \brief Thread local store to hold the Target context stack. 
*/ typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; -void Target::EnterTargetScope(const tvm::Target& target) { +void Target::EnterWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); - entry->context_stack.push(target); + entry->context_stack.push(*this); } -void Target::ExitTargetScope() { +void Target::ExitWithScope() { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } -tvm::Target Target::current_target(bool allow_not_defined) { +tvm::Target Target::Current(bool allow_not_defined) { TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); @@ -311,7 +308,7 @@ bool LLVMEnabled() { /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { - if (target->device_type == kDLCPU) { + if (target.defined() && target->device_type == kDLCPU) { return target; } else { if (LLVMEnabled()) { @@ -392,7 +389,11 @@ Stmt BuildStmt(Schedule sch, if (loop_partition) { stmt = ir::LoopPartition(stmt, config->partition_const_loop); } - stmt = ir::VectorizeLoop(stmt); + if (config->disable_vectorize) { + stmt = ir::SkipVectorize(stmt); + } else { + stmt = ir::VectorizeLoop(stmt); + } stmt = ir::InjectVirtualThread(stmt); stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); stmt = ir::StorageRewrite(stmt); @@ -570,7 +571,7 @@ runtime::Module build(const Map>& inputs, const BuildConfig& config) { Map> updated_input; for (const auto& it : inputs) { - auto target = Target::create(it.first); + auto target = Target::Create(it.first); updated_input.Set(target, it.second); } return build(updated_input, target_host, config); @@ -585,33 +586,35 @@ runtime::Module build(const Array& funcs, return build(inputs, target_host, config); } -BuildConfig build_config() { 
+BuildConfig BuildConfig::Create() { return BuildConfig(make_node()); } /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMBuildConfigThreadLocalEntry { /*! \brief The default build config if the stack is empty */ - tvm::BuildConfig default_config; + BuildConfig default_config; /*! \brief The current build config context */ - std::stack context_stack; + std::stack context_stack; TVMBuildConfigThreadLocalEntry() : - default_config(build_config()) { + default_config(BuildConfig::Create()) { } }; /*! \brief Thread local store to hold the BuildConfig context stack. */ typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; -void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) { +void BuildConfig::EnterWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - entry->context_stack.push(build_config); + entry->context_stack.push(*this); } -void BuildConfig::ExitBuildConfigScope() { +void BuildConfig::ExitWithScope() { TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); + CHECK(!entry->context_stack.empty()); + CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } @@ -642,6 +645,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; + p->stream << "disable_vectorize=" << op->disable_vectorize; p->stream << ")"; }); @@ -709,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector& tags, void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { auto node = static_cast(node_.get()); - auto target = Target::current_target(true); + auto target = Target::Current(true); PackedFunc func; if (target.defined()) { @@ -735,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig") *ret = BuildConfig::Current(); }); +class 
BuildConfig::Internal { + public: + static void EnterScope(BuildConfig target) { + target.EnterWithScope(); + } + static void ExitScope(BuildConfig target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig target = args[0]; - BuildConfig::EnterBuildConfigScope(target); - }); +.set_body_typed(BuildConfig::Internal::EnterScope); TVM_REGISTER_API("_ExitBuildConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig::ExitBuildConfigScope(); - }); +.set_body_typed(BuildConfig::Internal::ExitScope); TVM_REGISTER_API("_BuildConfigSetAddLowerPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -831,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc") TVM_REGISTER_API("_GetCurrentTarget") .set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; - *ret = Target::current_target(allow_not_defined); + *ret = Target::Current(allow_not_defined); }); +class Target::Internal { + public: + static void EnterScope(Target target) { + target.EnterWithScope(); + } + static void ExitScope(Target target) { + target.ExitWithScope(); + } +}; + TVM_REGISTER_API("_EnterTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target target = args[0]; - Target::EnterTargetScope(target); - }); +.set_body_typed(Target::Internal::EnterScope); TVM_REGISTER_API("_ExitTargetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Target::ExitTargetScope(); - }); +.set_body_typed(Target::Internal::ExitScope); } // namespace tvm diff --git a/src/codegen/codegen_aocl.cc b/src/codegen/codegen_aocl.cc index 6f899cbb0b53..03b9b6869d17 100644 --- a/src/codegen/codegen_aocl.cc +++ b/src/codegen/codegen_aocl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array funcs, std::string target_str, std::string cmd = "aoc aocl.cl"; // AOCL supports fp64. cmd += " -Dcl_khr_fp64"; - Target target = Target::create(target_str); + Target target = Target::Create(target_str); if (target->device_name != "") { cmd += " -board=" + target->device_name; } diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index ef92f9ae3175..22dde1c46389 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -57,6 +57,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_math_constants_h_) { + decl_stream << "#include \n"; + } + return CodeGenC::Finish(); } @@ -318,8 +322,19 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p) { / switch (op->type.bits()) { case 64: case 32: { std::ostringstream temp; - temp << std::scientific << op->value; - if (op->type.bits() == 32) temp << 'f'; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << ((op->type.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << ((op->type.bits() == 32) ? 
"CUDART_NAN_F" : "CUDART_NAN"); + p->need_math_constants_h_ = true; + } else { + temp << std::scientific << op->value; + if (op->type.bits() == 32) temp << 'f'; + } p->MarkConst(temp.str()); os << temp.str(); break; diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 381784a13a57..acd759f33889 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -39,7 +39,9 @@ class CodeGenCUDA final : public CodeGenC { void Init(bool output_ssa); void AddFunction(LoweredFunc f); std::string Finish(); - bool need_include_path() { return (enable_fp16_ || enable_int8_); } + bool need_include_path() { + return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + } // override behavior void VisitStmt_(const ir::For* op) final; void PrintStorageSync(const Call* op) final; @@ -70,6 +72,9 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable int8 bool enable_int8_{false}; + // whether need math_constants.h + bool need_math_constants_h_{false}; + friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); }; } // namespace codegen diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 382124a7ed2d..0b33bf43c151 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -247,6 +247,19 @@ void CodeGenOpenCL::VisitExpr_(const Select* op, std::ostream& os) { // NOLINT( CodeGenC::VisitExpr_(op, os); } +void CodeGenOpenCL::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) + if (std::isinf(op->value)) { + if (op->value < 0) { + os << "-"; + } + os << "INFINITY"; + } else if (std::isnan(op->value)) { + os << "NAN"; + } else { + CodeGenC::VisitExpr_(op, os); + } +} + runtime::Module BuildOpenCL(Array funcs) { using tvm::runtime::Registry; bool output_ssa = false; diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 0eff3a633ba3..36a55a545cbd 100644 --- a/src/codegen/codegen_opencl.h +++ 
b/src/codegen/codegen_opencl.h @@ -59,6 +59,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Select* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImm *op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index a18312fe6af5..4d86cc5b4b00 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { std::string xclbin; if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { - Target target = Target::create(target_str); + Target target = Target::Create(target_str); xclbin = (*f)(kernel_info, target->device_name).operator std::string(); } else { LOG(FATAL) << "Cannot compile Vivado HLS code."; diff --git a/src/codegen/datatype/registry.cc b/src/codegen/datatype/registry.cc new file mode 100644 index 000000000000..28cc58204e8d --- /dev/null +++ b/src/codegen/datatype/registry.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "registry.h" +#include + +namespace tvm { +namespace datatype { + +TVM_REGISTER_GLOBAL("_datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { + datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = datatype::Registry::Global()->GetTypeCode(args[0]); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeName(args[0].operator int()); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_registered").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); +}); + +Registry* Registry::Global() { + static Registry inst; + return &inst; +} + +void Registry::Register(const std::string& type_name, uint8_t type_code) { + CHECK(type_code >= kCustomBegin) << "Please choose a type code >= kCustomBegin for custom types"; + code_to_name_[type_code] = type_name; + name_to_code_[type_name] = type_code; +} + +uint8_t Registry::GetTypeCode(const std::string& type_name) { + CHECK(name_to_code_.find(type_name) != name_to_code_.end()) + << "Type name " << type_name << " not registered"; + return name_to_code_[type_name]; +} + +std::string Registry::GetTypeName(uint8_t type_code) { + 
CHECK(code_to_name_.find(type_code) != code_to_name_.end()) + << "Type code " << static_cast(type_code) << " not registered"; + return code_to_name_[type_code]; +} + +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target << "."; + ss << "Cast" + << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + ss << datatype::Registry::Global()->GetTypeName(type_code); + } else { + ss << runtime::TypeCode2Str(type_code); + } + + ss << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { + ss << datatype::Registry::Global()->GetTypeName(src_type_code); + } else { + ss << runtime::TypeCode2Str(src_type_code); + } + + return runtime::Registry::Get(ss.str()); +} + +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target; + ss << ".FloatImm."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + return runtime::Registry::Get(ss.str()); +} + +uint64_t ConvertConstScalar(uint8_t type_code, double value) { + std::ostringstream ss; + ss << "tvm.datatype.convertconstscalar.float."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + auto make_const_scalar_func = runtime::Registry::Get(ss.str()); + return (*make_const_scalar_func)(value).operator uint64_t(); +} + +} // namespace datatype +} // namespace tvm diff --git a/src/codegen/datatype/registry.h b/src/codegen/datatype/registry.h new file mode 100644 index 000000000000..d2e615765a18 --- /dev/null +++ b/src/codegen/datatype/registry.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_CODEGEN_DATATYPE_REGISTRY_H_ +#define TVM_CODEGEN_DATATYPE_REGISTRY_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace datatype { + +/*! + * \brief Registry for custom datatypes. + * + * Adding custom datatypes currently requires two steps: + * 1. Register the datatype with the registry via a call to + * datatype::Registry::Register. This can also be done in Python + * directly---see the TVM globals registered in the corresponding .cc file. + * Currently, user should manually choose a type name and a type code, + * ensuring that neither conflict with existing types. + * 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to + * lower the custom datatype. In general, these will look like: + * For Casts: tvm.datatype.lower..Cast.. + * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from + * float to myfloat. + * For other ops: tvm.datatype.lower... + * Examples: tvm.datatype.lower.llvm.Add.myfloat + * tvm.datatype.lower.llvm.FloatImm.posit + */ +class Registry { + public: + /*! + * \brief Get the global custom datatype registry singleton + */ + static Registry* Global(); + + /*! + * \brief Register custom datatype + * Register a custom datatype with the given type name and type code. 
Currently, the type code is + * manually allocated by the user, and the user must ensure that no two custom types share the + * same code. Generally, this should be straightforward, as the user will be manually registering + * all of their custom types. + * \param type_name The name of the type, e.g. "bfloat" + * \param type_code The type code, which should be greater than TVMTypeCode::kExtEnd + */ + void Register(const std::string& type_name, uint8_t type_code); + + /*! + * \brief Get type code from type name + * \param type_name The type name + * \return The type code + */ + uint8_t GetTypeCode(const std::string &type_name); + + /*! + * \brief Get type name from type code + * \param type_code The type code + * \return The type name + */ + std::string GetTypeName(uint8_t type_code); + + /*! + * \brief Get bool representing whether type is registered, given the type code + * \param type_code The type code + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(uint8_t type_code) { + return code_to_name_.find(type_code) != code_to_name_.end(); + } + + /*! + * \brief Get bool representing whether type is registered, given the type name + * \param type_name The type name + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(std::string type_name) { + return name_to_code_.find(type_name) != name_to_code_.end(); + } + + private: + // TODO(gus) is there a typedef for the code? + std::unordered_map code_to_name_; + std::unordered_map name_to_code_; +}; + +/*! + * \brief Convert scalar value to a custom datatype format + * \param type_code The custom datatype to convert to, specified by type code + * \param value The floating point value to convert + * \return The value, encoded in the bits of a uint64_t + */ +uint64_t ConvertConstScalar(uint8_t type_code, double value); + +/*! + * \brief Get lowering function for Cast ops + * \param target The target we are lowering to, e.g. 
"llvm" + * \param type_code The datatype being cast to + * \param src_type_code The datatype being cast from + * \return Lowering function for Cast ops for the provided target, type, and source type + */ +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code); + +/*! + * \brief Get lowering function for FloatImms + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the FloatImm + * \return Lowering function for FloatImms for the provided target and type + */ +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code); + +/*! + * \brief Get lowering function for other ops + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the op + * \return Lowering function for other ops for the provided target and type + */ +#define DEFINE_GET_LOWER_FUNC_(OP) \ + inline const runtime::PackedFunc* Get##OP##LowerFunc(const std::string& target, \ + uint8_t type_code) { \ + return runtime::Registry::Get("tvm.datatype.lower." + target + "." #OP "." + \ + datatype::Registry::Global()->GetTypeName(type_code)); \ + } + +DEFINE_GET_LOWER_FUNC_(Add) +DEFINE_GET_LOWER_FUNC_(Sub) +DEFINE_GET_LOWER_FUNC_(Mul) +DEFINE_GET_LOWER_FUNC_(Div) +DEFINE_GET_LOWER_FUNC_(Mod) +DEFINE_GET_LOWER_FUNC_(Min) +DEFINE_GET_LOWER_FUNC_(Max) +DEFINE_GET_LOWER_FUNC_(EQ) +DEFINE_GET_LOWER_FUNC_(NE) +DEFINE_GET_LOWER_FUNC_(LT) +DEFINE_GET_LOWER_FUNC_(LE) +DEFINE_GET_LOWER_FUNC_(GT) +DEFINE_GET_LOWER_FUNC_(GE) +// Later changes may need to add more lowering functions as we support workloads with more ops. 
+ +} // namespace datatype +} // namespace tvm + +#endif // TVM_CODEGEN_DATATYPE_REGISTRY_H_ diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 7946f906125f..1e56583a37fd 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -537,6 +537,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) { if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); + } else if (to.is_uint() && to.bits() == 1) { + if (from.is_float()) { + llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.); + return builder_->CreateFCmpONE(value, zero); + } else { + llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0); + return builder_->CreateICmpNE(value, zero); + } } else if (!from.is_float() && !to.is_float()) { return builder_->CreateIntCast(value, target, from.is_int()); } else if (from.is_float() && to.is_int()) { @@ -1134,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc index fda239f0766f..e2a788f1bbd4 100644 --- a/src/codegen/opt/build_cuda_on.cc +++ b/src/codegen/opt/build_cuda_on.cc @@ -84,12 +84,13 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { std::vector 
compile_params; std::vector param_cstrings{}; nvrtcProgram prog; - cudaDeviceProp device_prop; std::string cc = "30"; - cudaError_t e = cudaGetDeviceProperties(&device_prop, 0); + int major, minor; + cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0); + cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0); - if (e == cudaSuccess) { - cc = std::to_string(device_prop.major) + std::to_string(device_prop.minor); + if (e1 == cudaSuccess && e2 == cudaSuccess) { + cc = std::to_string(major) + std::to_string(minor); } else { LOG(WARNING) << "cannot detect compute capability from your device, " << "fall back to compute_30."; diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index e6fc0088dc81..fd113ca4614a 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { - arith::ConstraintContext cctx(analyzer_.get(), op->condition); + With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } diff --git a/src/common/socket.h b/src/common/socket.h index 58705f16bf73..91f9f4e5cf0a 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -373,7 +373,7 @@ class TCPSocket : public Socket { } /*! 
* \brief decide whether the socket is at OOB mark - * \return 1 if at mark, 0 if not, -1 if an error occured + * \return 1 if at mark, 0 if not, -1 if an error occurred */ int AtMark() const { #ifdef _WIN32 diff --git a/src/contrib/cblas/cblas.cc b/src/contrib/cblas/cblas.cc index 4ca043f1bcfe..0f222e2f2a39 100644 --- a/src/contrib/cblas/cblas.cc +++ b/src/contrib/cblas/cblas.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,12 +21,11 @@ * Copyright (c) 2017 by Contributors * \file Use external cblas library call. */ +#include #include #include -#include #include "gemm_common.h" - extern "C" { #if USE_MKL_BLAS == 1 #include @@ -40,56 +39,148 @@ namespace contrib { using namespace runtime; -inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { - return trans ? CblasTrans : CblasNoTrans; -} +inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? 
CblasTrans : CblasNoTrans; } struct CblasSgemmOp { typedef float TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - cblas_sgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; struct CblasDgemmOp { typedef double TDatatype; - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - cblas_dgemm(CblasColMajor, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - alpha, A, lda, - B, ldb, - beta, C, ldc); + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); } }; +struct CblasSgemmBatchOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_sgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + 
B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; + +struct CblasSgemmBatchIterativeOp { + typedef float TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, + int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_sgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; + +struct CblasDgemmBatchOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); +#if USE_MKL_BLAS == 1 + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = A + i * a_stride; + B_array[i] = B + i * b_stride; + C_array[i] = C + i * c_stride; + } + cblas_dgemm_batch(CblasColMajor, &trans_a, &trans_b, &M, &N, &K, &alpha, A_array.data(), &lda, + B_array.data(), &ldb, &beta, C_array.data(), &ldc, 1, &batch_size); +#else + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } +#endif + } +}; + +struct CblasDgemmBatchIterativeOp { + typedef double TDatatype; + void operator()(int batch_size, bool ta, bool tb, int M, 
int N, int K, double alpha, double* A, + int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, + int c_stride, int ldc) { + CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta); + CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb); + for (int i = 0; i < batch_size; ++i) { + cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + A += a_stride; + B += b_stride; + C += c_stride; + } + } +}; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || - TypeMatch(A->dtype, kDLFloat, 64)); +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + + if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CblasSgemmOp()); + else + CallGemm(args, ret, CblasDgemmOp()); +}); - if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CblasSgemmOp()); - else - CallGemm(args, ret, CblasDgemmOp()); - }); +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchOp()); + } +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/cblas/gemm_common.h 
b/src/contrib/cblas/gemm_common.h index fe38b2a67513..2bcefb2f26bb 100644 --- a/src/contrib/cblas/gemm_common.h +++ b/src/contrib/cblas/gemm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,16 +22,17 @@ * \file tvm/contrib/gemm.h * \brief Shared implementation of gemm */ -#ifndef TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ -#define TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ +#pragma once + +#include +#include #include namespace tvm { namespace contrib { using namespace runtime; - -inline int ColumnStride(DLTensor* tensor) { +inline int ColumnStride(DLTensor *tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,8 +43,7 @@ inline int ColumnStride(DLTensor* tensor) { } } - -inline int ElementStride(DLTensor* tensor) { +inline int ElementStride(DLTensor *tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,29 +51,26 @@ inline int ElementStride(DLTensor* tensor) { } } - // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed(DLTensor* tensor) { +inline bool IsInPlaceTransposed(DLTensor *tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } - -inline int RowCount(DLTensor* tensor, bool trans) { +inline int RowCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } - -inline int ColumnCount(DLTensor* tensor, bool trans) { +inline int ColumnCount(DLTensor *tensor, bool trans) { return tensor->shape[trans ? 
0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. -template +template inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -96,25 +93,88 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, - transa, - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - static_cast(alpha), - reinterpret_cast(static_cast(B->data) - + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), static_cast(alpha), + reinterpret_cast( + static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast(static_cast(A->data) - + A->byte_offset), - ColumnStride(A), - static_cast(beta), - reinterpret_cast(static_cast(C->data) - + C->byte_offset), + reinterpret_cast( + static_cast(A->data) + A->byte_offset), + ColumnStride(A), static_cast(beta), + reinterpret_cast( + static_cast(C->data) + C->byte_offset), ColumnStride(C)); } +inline int ColumnStride3D(DLTensor *tensor) { + // If the tensor itself is transposed then it will have strides + // backward from what we expect. Regardless, the max of the strides + // (the other stride is 1) is the column stride. + if (tensor->strides) { + return std::max(tensor->strides[1], tensor->strides[2]); + } else { + return tensor->shape[2]; + } +} +inline int ElementStride3D(DLTensor *tensor) { + if (tensor->strides) { + return std::min(tensor->strides[1], tensor->strides[2]); + } else { + return 1; + } +} +// Reversed strides indicates an in-place transpose operation. 
+inline bool IsInPlaceTransposed3D(DLTensor *tensor) { + return tensor->strides && (tensor->strides[2] > tensor->strides[1]); +} +inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 2 : 1]; +} +inline int ColumnCount3D(DLTensor *tensor, bool trans) { + return tensor->shape[trans ? 1 : 2]; +} +template +inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { + using DType = typename TBatchGemmOp::TDatatype; + DLTensor *A = args[0]; + DLTensor *B = args[1]; + DLTensor *C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + CHECK_EQ(A->ndim, 3); + CHECK_EQ(B->ndim, 3); + CHECK_EQ(C->ndim, 3); + int batch_size = BatchCount3D(A); + CHECK_EQ(BatchCount3D(B), batch_size); + CHECK_EQ(BatchCount3D(C), batch_size); + CHECK_EQ(ElementStride(A), 1); + CHECK_EQ(ElementStride(B), 1); + CHECK_EQ(ElementStride(C), 1); + // C can never be transposed. + CHECK(!IsInPlaceTransposed3D(C)); + // Reversed strides indicates an in-place transpose operation. + transa = IsInPlaceTransposed3D(A) ? !transa : transa; + transb = IsInPlaceTransposed3D(B) ? !transb : transb; + CHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); + CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; + double beta = args.size() > 6 ? 
args[6] : 0.0; + const int A_size = A->shape[1] * A->shape[2]; + const int B_size = B->shape[1] * B->shape[2]; + const int C_size = C->shape[1] * C->shape[2]; + DType *A_data = reinterpret_cast( + static_cast(A->data) + A->byte_offset); + DType *B_data = reinterpret_cast( + static_cast(B->data) + B->byte_offset); + DType *C_data = reinterpret_cast( + static_cast(C->data) + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast(alpha), + B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + static_cast(beta), C_data, C_size, ColumnStride3D(C)); +} + } // namespace contrib } // namespace tvm - -#endif // TVM_CONTRIB_CBLAS_GEMM_COMMON_H_ diff --git a/src/contrib/sort/sort.cc b/src/contrib/sort/sort.cc index cf25e89b9109..87691f254c5c 100644 --- a/src/contrib/sort/sort.cc +++ b/src/contrib/sort/sort.cc @@ -34,14 +34,14 @@ namespace contrib { using namespace runtime; template -bool CompareAscend(const std::pair& lhs, - const std::pair& rhs) { +bool CompareAscend(const std::pair& lhs, + const std::pair& rhs) { return lhs.second < rhs.second; } template -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { +bool CompareDescend(const std::pair& lhs, + const std::pair& rhs) { return lhs.second > rhs.second; } @@ -110,6 +110,41 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") } }); +template +void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { + auto data_ptr = static_cast(input->data); + auto out_ptr = static_cast(output->data); + std::vector > sorter; + + int axis_mul_before = 1; + int axis_mul_after = 1; + for (int i = 0; i < input->ndim; ++i) { + if (i < axis) { + axis_mul_before *= input->shape[i]; + } else if (i > axis) { + axis_mul_after *= input->shape[i]; + } + } + + for (int i = 0 ; i < axis_mul_before; ++i) { + for (int j = 0 ; j < axis_mul_after; ++j) { + sorter.clear(); + int64_t base_idx = i * input->shape[axis] 
* axis_mul_after + j; + for (int64_t k = 0; k < input->shape[axis]; ++k) { + int64_t full_idx = base_idx + k * axis_mul_after; + sorter.emplace_back(std::make_pair(k, data_ptr[full_idx])); + } + if (is_ascend) { + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + } else { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + } + for (int64_t k = 0; k < input->shape[axis]; ++k) { + out_ptr[base_idx + k * axis_mul_after] = static_cast(sorter[k].first); + } + } + } +} // Argsort implemented C library sort. // Return indices of sorted tensor. @@ -124,25 +159,84 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") DLTensor *output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; - - auto dtype = input->dtype; - auto data_ptr = static_cast(input->data); - std::vector> sorter; - int64_t axis_mul_before = 1; - int64_t axis_mul_after = 1; - if (axis < 0) { axis = input->ndim + axis; } - - // Currently only supports input dtype to be float32. - CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float32."; - CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " << input->ndim; + + auto data_dtype = TVMType2String(input->dtype); + auto out_dtype = TVMType2String(output->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, 
is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "int64") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float32") { + argsort(input, output, axis, is_ascend); + } else if (out_dtype == "float64") { + argsort(input, output, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } +}); +template +void topk(DLTensor* input, + DLTensor* out_values, + DLTensor* out_indices, + int k, + int axis, + bool is_ascend) { + DataType* data_ptr = static_cast(input->data); + DataType* values_ptr = (out_values == nullptr) ? nullptr : + static_cast(out_values->data); + IndicesType* indices_ptr = (out_indices == nullptr) ? 
nullptr : + static_cast(out_indices->data); + std::vector > sorter; + + int axis_mul_before = 1; + int axis_mul_after = 1; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { axis_mul_before *= input->shape[i]; @@ -150,27 +244,124 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") axis_mul_after *= input->shape[i]; } } + if (k < 1) { + k = input->shape[axis]; + } - int32_t current_sort_num = input->shape[axis]; - for (int64_t i = 0 ; i < axis_mul_before; ++i) { - for (int64_t j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0 ; i < axis_mul_before; ++i) { + for (int j = 0 ; j < axis_mul_after; ++j) { sorter.clear(); - int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; - for (int64_t k = 0; k < current_sort_num; ++k) { - int64_t full_idx = base_idx + k * axis_mul_after; - sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); + int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; + int64_t dst_base_idx = i * k * axis_mul_after + j; + for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { + int64_t full_idx = src_base_idx + kk * axis_mul_after; + sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); } if (is_ascend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); } - for (int32_t k = 0; k < input->shape[axis]; ++k) { - *(static_cast(output->data) + base_idx + k * axis_mul_after) - = k < static_cast(sorter.size()) ? sorter[k].first : k; + int64_t cnt = k > 0 ? 
k : input->shape[axis]; + for (int64_t kk = 0; kk < cnt; ++kk) { + if (indices_ptr != nullptr) { + indices_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(sorter[kk].first); + } + if (values_ptr != nullptr) { + values_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(sorter[kk].second); + } } } } +} + +// Argsort implemented C library sort. +// Return indices of sorted tensor. +// By default, the last axis will be used to sort. +// sort_num specify the number of elements to be sorted. +// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) +// and sort axis is dk. sort_num should have dimension of +// (d1, d2, ..., d(k-1), d(k+1), ..., dn). +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") +.set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* values_out = nullptr; + DLTensor* indices_out = nullptr; + int k = args[args.num_args - 4]; + int axis = args[args.num_args - 3]; + std::string ret_type = args[args.num_args - 2]; + bool is_ascend = args[args.num_args - 1]; + if (ret_type == "both") { + values_out = args[1]; + indices_out = args[2]; + } else if (ret_type == "values") { + values_out = args[1]; + } else if (ret_type == "indices") { + indices_out = args[1]; + } else { + LOG(FATAL) << "Unsupported ret type: " << ret_type; + } + if (axis < 0) { + axis = input->ndim + axis; + } + CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; + + auto data_dtype = TVMType2String(input->dtype); + auto out_dtype = (indices_out == nullptr) ? 
"int64" : TVMType2String(indices_out->dtype); + + if (data_dtype == "float32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "float64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int32") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } + } else if (data_dtype == "int64") { + if (out_dtype == "int32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "int64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float32") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else if (out_dtype == "float64") { + topk(input, values_out, indices_out, k, axis, is_ascend); + } else { + LOG(FATAL) << "Unsupported 
output dtype: " << out_dtype; + } + } else { + LOG(FATAL) << "Unsupported input dtype: " << data_dtype; + } }); } // namespace contrib diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 4504ee23f812..3f5254069b8d 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) { return ir::Mod::make(a, b); } + Expr min(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return b; + if (is_neg_inf(a)) return a; + if (is_pos_inf(b)) return a; + if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) { } Expr max(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return a; + if (is_neg_inf(a)) return b; + if (is_pos_inf(b)) return b; + if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -393,6 +408,16 @@ Expr sum(Expr source, Array rdom) { return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } +Expr all(Expr source, Array rdom) { + CHECK(source.type().is_bool()); + Var x("x", source.type()), y("y", source.type()); + Expr result = ir::And::make(x, y); + Expr identity_element = make_const(source.type(), true); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return 
ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); +} + Expr max(Expr source, Array rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Max::make(x, y); diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc index bab7cf6d93ed..d885d7103606 100644 --- a/src/lang/tensor.cc +++ b/src/lang/tensor.cc @@ -83,6 +83,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { @@ -91,6 +92,7 @@ TensorIntrin TensorIntrinNode::make(std::string name, n->op = std::move(op); n->inputs = std::move(inputs); n->buffers = std::move(buffers); + n->scalar_params = std::move(scalar_params); n->body = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); @@ -110,12 +112,14 @@ TVM_REGISTER_NODE_TYPE(TensorIntrinNode); TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, Array tensors, Array regions, - Array reduce_axis) { + Array reduce_axis, + Array scalar_inputs) { auto n = make_node(); n->intrin = std::move(intrin); n->tensors = std::move(tensors); n->regions = std::move(regions); n->reduce_axis = std::move(reduce_axis); + n->scalar_inputs = std::move(scalar_inputs); return TensorIntrinCall(n); } diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index e6c6039b610e..7023aebe17ad 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name, CHECK_EQ(inputs.size(), input_placeholders.size()); for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); - CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape)); + CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); + for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { + CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + } 
CHECK_EQ(input_placeholders[i]->strides.size(), 0U); } n->inputs = std::move(inputs); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index ed768c2ba216..09e8af7d5cba 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -58,7 +58,8 @@ Operation TensorComputeOpNode::make(std::string name, int schedulable_ndim, TensorIntrin intrin, Array tensors, - Array regions) { + Array regions, + Array scalar_inputs) { auto n = make_node(); n->name = std::move(name); n->tag = std::move(tag); @@ -68,6 +69,7 @@ Operation TensorComputeOpNode::make(std::string name, n->intrin = std::move(intrin); n->inputs = std::move(tensors); n->input_regions = std::move(regions); + n->scalar_inputs = std::move(scalar_inputs); return Operation(n); } @@ -184,6 +186,19 @@ Stmt TensorComputeOpNode::BuildProvide( std::unordered_map vmap; ir::ArgBinder binder(&vmap); + // Map the expressions passed in the call to the TensorIntrin, to the placeholder + // variables + Array user_expr = this->scalar_inputs; + Array scalar_params = this->intrin->scalar_params; + Array sp_expr; + for (auto sp : scalar_params) { + Expr esp = sp; + sp_expr.push_back(esp); + } + CHECK_EQ(sp_expr.size(), user_expr.size()); + // TODO(jdavies-huawei): what name should be used here? + binder.BindArray(sp_expr, user_expr, this->name); + size_t tloc = stage->leaf_iter_vars.size(); ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); diff --git a/src/pass/arg_binder.h b/src/pass/arg_binder.h index 9de3a13270dc..f235ea49faac 100644 --- a/src/pass/arg_binder.h +++ b/src/pass/arg_binder.h @@ -50,7 +50,7 @@ namespace ir { * - assert bufferB.shape[1] == n + 3 * * In general, this is a constraint solving problem. We have simplified assumption - * over the binding declaration, such that we require the variable occured in + * over the binding declaration, such that we require the variable occurred in * constraint must be declared in argument list. 
So it is illegal to have signature * f(tA(shape=(n+3))) without any argument variable corresponds to n, even though * it is already enough to derive n from the input argument. diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 04bb9385b156..0a5b7410f3cf 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #include #include #include -#include "../arithmetic/int_set_internal.h" +#include "../arithmetic/int_set.h" #include "../runtime/thread_storage_scope.h" namespace tvm { @@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator { std::pair> GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value); inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); @@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator { /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; std::unordered_map relax_map_; + arith::Analyzer analyzer_; CandidateSelector selector; }; @@ -381,29 +382,17 @@ class LoopPartitioner : public IRMutator { // given in the second component provably have value given by cond_value std::pair> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { - arith::Interval interval = kv.second.as()->i; - auto intersection = 
arith::Interval::make_intersection(interval, for_interval); - - // TODO(derisavi): the following if statement needs to be removed as soon as - // TVM uses commit a768f2f0 of HalideIR repo - if (intersection.min.same_as(arith::Interval::pos_inf) || - intersection.max.same_as(arith::Interval::neg_inf)) { - intersection = arith::Interval::nothing(); - } else if (intersection.min.type() == intersection.max.type() && - (intersection.min.type().is_int() || - intersection.min.type().is_uint()) && - can_prove(intersection.min > intersection.max)) { - intersection = arith::Interval::nothing(); - } - - if (!intersection.is_empty()) { + arith::IntervalSet interval = Downcast(kv.second); + arith::IntervalSet intersection = arith::Intersect( + &analyzer_, interval, for_interval); + if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); } @@ -476,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr max, Stmt body, bool partition_thread_scope) { + using namespace arith; PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); if (finder.partitions.empty()) return Stmt(); - arith::Interval for_interval(min, max); + arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; std::unordered_set cond_set; @@ -491,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // if such interval doesn't exist, find an interval in which all // conditions on var are false std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, false); + GetIntervalAndCondset(finder.partitions, for_interval, false); if (middle_interval.is_nothing()) // we couldn't find an interval in which the condintions are provably true or false // Therefore, we can't partition the loop based on those conds @@ -501,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, cond_value = true; } - arith::Interval middle_interval_i = middle_interval.as()->i; + IntervalSet 
middle_interval_i = Downcast(middle_interval); // middle_interval is the subrange of the loop variable range for which a // set of conditions are true (or false resp.) // The part of the loop variable range that is before (after resp.) that @@ -512,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr body_begin; Stmt pre_stmt; bool pre_stmt_recurse = true; - if (middle_interval_i.has_lower_bound()) { + if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); if (!can_prove(body_begin == min)) { Expr cond = (body_begin - min >= 0); @@ -537,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; - if (middle_interval_i.has_upper_bound()) { + if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); if (!can_prove(middle_interval.max() == max)) { // require the extent to be non-negative diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc new file mode 100644 index 000000000000..7598ef49eee0 --- /dev/null +++ b/src/pass/lower_custom_datatypes.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/src/pass/lower_custom_datatypes.cc + * \brief Pass for lowering custom datatypes + */ + +#include +#include +#include +#include "../codegen/datatype/registry.h" + +namespace tvm { +namespace ir { + +/*! + * \brief Helper mutator to implement lowering of custom datatypes. + * + * Lowering datatypes works as follows: for every expression containing a custom + * datatype, we search for a global (registered by the implementer of the custom + * datatype) for lowering this type of expression, and uses it to lower the + * expression. + */ +class CustomDatatypesLowerer : public IRMutator { + public: + explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} + + inline Expr Mutate_(const Cast* op, const Expr& e) final { + auto type_code = op->type.code(); + auto src_type_code = op->value.type().code(); + // If either datatype is a registered custom datatype, we must lower. + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || + datatype::Registry::Global()->GetTypeRegistered(src_type_code); + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + if (toBeLowered) { + auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); + CHECK(lower) << "Cast lowering function for target " << target_ << " destination type " + << static_cast(type_code) << " source type " + << static_cast(src_type_code) << " not found"; + return (*lower)(expr); + } + return expr; + } + + inline Expr Mutate_(const FloatImm* imm, const Expr& e) final { + auto type_code = imm->type.code(); + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); + CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " + << static_cast(type_code) << " not found"; + return (*lower)(e); + } + return e; + } + + inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final { + bool 
toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->type.code()); + Stmt stmt = IRMutator::Mutate_(allocate, s); + allocate = stmt.as(); + + if (toBeLowered) { + auto new_allocate_type = UInt(allocate->type.bits(), allocate->type.lanes()); + return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body, allocate->new_expr, + allocate->free_function); + } + return stmt; + } + + inline Expr Mutate_(const Load* load, const Expr& e) final { + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->type.code()); + Expr expr = IRMutator::Mutate_(load, e); + load = expr.as(); + if (toBeLowered) { + auto new_load_type = UInt(load->type.bits()); + return Load::make(new_load_type, load->buffer_var, load->index, load->predicate); + } + return expr; + } + +#define DEFINE_MUTATE__(OP) \ + inline Expr Mutate_(const OP* op, const Expr& e) final { \ + auto type_code = op->type.code(); \ + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ + Expr expr = IRMutator::Mutate_(op, e); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ + } + + DEFINE_MUTATE__(Add) + DEFINE_MUTATE__(Sub) + DEFINE_MUTATE__(Mul) + DEFINE_MUTATE__(Div) + DEFINE_MUTATE__(Mod) + DEFINE_MUTATE__(Min) + DEFINE_MUTATE__(Max) + DEFINE_MUTATE__(EQ) + DEFINE_MUTATE__(NE) + DEFINE_MUTATE__(LT) + DEFINE_MUTATE__(LE) + DEFINE_MUTATE__(GT) + DEFINE_MUTATE__(GE) + // Later changes may need to add more mutate functions as we support workloads with more ops. 
+ + private: + std::string target_; +}; + +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { + auto n = make_node(*f.operator->()); + n->body = CustomDatatypesLowerer(target).Mutate(n->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index f87e80c2d030..8c3d383c1529 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -519,5 +519,23 @@ Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer().Mutate(stmt); } +class VectorizeSkipper : public IRMutator { + public: + Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + if (op->for_type == ForType::Vectorized) { + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, + op->body); + } else { + return stmt; + } + } +}; + +Stmt SkipVectorize(Stmt stmt) { + return VectorizeSkipper().Mutate(stmt); +} + } // namespace ir } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 08a88d53350f..3feb7e4a4b54 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -18,18 +18,13 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. 
*/ - #include -#include +#include #include -#include -#include -#include -#include +#include #include #include "utils.h" @@ -38,88 +33,8 @@ namespace tvm { namespace relay { namespace backend { -/*! - * \brief Context name / index - * See: python/tvm/_ffi/runtime_ctypes.py - */ -struct ContextMap { - static const std::unordered_map mask2str; - static const std::unordered_map str2mask; - static std::string Mask2Str(int mask) { - CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; - return mask2str.at(mask); - } - static int Str2Mask(const std::string& str) { - CHECK_GT(str2mask.count(str), 0) << "Unknown context."; - return str2mask.at(str); - } -}; - -const std::unordered_map ContextMap::mask2str = { - {1, "cpu"}, - {2, "gpu"}, - {4, "opencl"}, - {5, "aocl"}, - {6, "sdaccel"}, - {7, "vulkan"}, - {8, "metal"}, - {9, "vpi"}, - {10, "rocm"}, - {11, "opengl"}, - {12, "ext_dev"} -}; - -const std::unordered_map ContextMap::str2mask = { - {"llvm", 1}, - {"cpu", 1}, - {"c", 1}, - {"gpu", 2}, - {"cuda", 2}, - {"nvptx", 2}, - {"cl", 4}, - {"opencl", 4}, - {"aocl", 5}, - {"aocl_sw_emu", 5}, - {"vulkan", 7}, - {"metal", 8}, - {"vpi", 9}, - {"rocm", 10}, - {"opengl", 11}, - {"ext_dev", 12} -}; - -/*! - * \brief A data structure to map the names of specific optimizations to - * numeric optimization levels - * - */ -struct OptPassLevel { - static const std::unordered_map _data; - /*! - * \brief Get level for an optimization pass - * - * \param key pass name - * \return int level - */ - int operator[](const std::string& key) const { - auto it = _data.find(key); - if (it == _data.end()) { - return -1; - } - return it->second; - } -}; - -const std::unordered_map OptPassLevel::_data = { - {"SimplifyInference", 0}, - {"OpFusion", 1}, - {"FoldConstant", 2}, - {"CombineParallelConv2D", 3}, - {"FoldScaleAxis", 3}, - {"AlterOpLayout", 3}, - {"CanonicalizeOps", 3}, - {"EliminateCommonSubexpr", 3} -}; +using TargetsMap = Map; +using namespace tvm::relay::transform; /*! 
* \brief Output of building module @@ -131,27 +46,6 @@ struct BuildOutput { std::unordered_map params; }; -/*! - * \brief Relay building config - * - */ -struct RelayBuildConfig { - int opt_level{2}; - std::string fallback_device{"llvm"}; - std::unordered_set enabled_pass; - std::unordered_set disabled_pass; - OptPassLevel OPT_PASS_LEVEL; - inline bool pass_enabled(const std::string& pass_name) const { - if (disabled_pass.count(pass_name)) { - return false; - } - if (enabled_pass.count(pass_name)) { - return true; - } - return opt_level >= OPT_PASS_LEVEL[pass_name]; - } -}; - /*! * \brief GraphCodegen module wrapper * @@ -164,14 +58,8 @@ struct GraphCodegen { } ~GraphCodegen() {} - void Init(runtime::Module* m, - Map targets) { - Array tgts; - for (auto kv : targets) { - tgts.push_back(kv.first); - tgts.push_back(kv.second); - } - CallFunc("init", m, tgts); + void Init(runtime::Module* m, TargetsMap targets) { + CallFunc("init", m, targets); } void Codegen(const Function& func) { @@ -211,18 +99,6 @@ struct GraphCodegen { } }; -template -R CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - -template -Function CallPackedFunc(const std::string &name, Args... args) { - auto pf = GetPackedFunc(name); - return (*pf)(std::forward(args)...); -} - /*! 
* \brief Relay build module * @@ -248,14 +124,7 @@ class RelayBuildModule : public runtime::ModuleNode { } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); - Array tmp = args[1]; - std::unordered_map targets; - for (size_t i = 0; i < tmp.size(); i += 2) { - auto k = tmp[i].as()->value; - auto v = tmp[i + 1].as()->value; - targets[k] = v; - } - this->Build(args[0], targets, args[2]); + this->Build(args[0], args[1], args[2]); }); } else if (name == "list_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -265,27 +134,6 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); - } else if (name == "set_opt_level") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 1); - int level = args[0]; - this->SetOptLevel(level); - }); - } else if (name == "set_fallback_device") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string dev = args[0]; - this->SetFallBackDev(dev); - }); - } else if (name == "add_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->AddPass(pass_name); - }); - } else if (name == "disable_pass") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::string pass_name = args[0]; - this->DisablePass(pass_name); - }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map params = args[0]; @@ -307,30 +155,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::string& GetGraphJSON() { return ret_.graph_json; } - /*! - * \brief Add extra pass into build cfg - * - * \param pass_name name of pass - */ - void AddPass(const std::string& pass_name) { - cfg_.enabled_pass.insert(pass_name); - } - /*! 
- * \brief Disable a specific pass in cfg - * - * \param pass_name name of pass - */ - void DisablePass(const std::string& pass_name) { - cfg_.disabled_pass.insert(pass_name); - } - /*! - * \brief Set the Fallback device - * - * \param device name - */ - void SetFallBackDev(const std::string& dev) { - cfg_.fallback_device = dev; - } + /*! * \brief Get the Module object * @@ -345,8 +170,8 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return Array names of params */ - Array ListParamNames() { - Array ret; + Array ListParamNames() { + Array ret; for (const auto& kv : params_) { ret.push_back(ir::StringImm::make(kv.first)); } @@ -376,15 +201,6 @@ class RelayBuildModule : public runtime::ModuleNode { params_[name] = data_in; } - /*! - * \brief Set the optimization level - * - * \param level - */ - void SetOptLevel(char level) { - cfg_.opt_level = level; - } - /*! * \brief type key * @@ -402,11 +218,11 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target_host Host target device */ void Build(Function func, - const std::unordered_map& targets, - const std::string& target_host) { + const TargetsMap& targets, + const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, cfg_, params_); + BuildRelay(func, params_); } protected: @@ -416,8 +232,9 @@ class RelayBuildModule : public runtime::ModuleNode { * \param params params dict * \return relay::Function */ - relay::Function BindParamsByName(relay::Function func, - const std::unordered_map& params) { + relay::Function BindParamsByName( + relay::Function func, + const std::unordered_map& params) { std::unordered_map name_dict; std::unordered_set repeat_var; for (auto arg : func->params) { @@ -438,149 +255,147 @@ class RelayBuildModule : public runtime::ModuleNode { if (repeat_var.count(arg)) { LOG(FATAL) << "Multiple args in the function have name " << kv.first; } - auto e = CallPackedFunc("relay._make.Constant", kv.second); - bind_dict[arg] = e; + 
bind_dict[arg] = ConstantNode::make(kv.second); } - return CallPackedFunc("relay._expr.Bind", func, tvm::Map(bind_dict)); + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; } /*! - * \brief Optimize Relay function + * \brief Optimize a Relay module. * - * \param func Input function - * \param target target device - * \param cfg Relay build config - * \param params params dict - * \return relay::Function + * \param relay_module The input Relay module where optmization will be + * applied on. + * \param targets The device type to `Target` mapping. + * \param params The param name to value mapping. + * + * \return relay::Module The updated Relay module after optimization. */ - relay::Function Optimize(relay::Function func, - const std::unordered_map& targets, - const RelayBuildConfig& cfg, - const std::unordered_map& params) { - if (params.size()) { - func = BindParamsByName(func, params); - } - if (cfg.pass_enabled("SimplifyInference")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.simplify_inference", func); - } - if (cfg.pass_enabled("EliminateCommonSubexpr")) { - auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } + relay::Module Optimize( + relay::Module relay_module, + const TargetsMap& targets, + const std::unordered_map& params) { + Array pass_seqs; + pass_seqs.push_back(transform::SimplifyInference()); + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") 
{ + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; } } - *rv = false; - }); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip); - } - if (cfg.pass_enabled("CombineParallelConv2D")) { - const int min_num_branches = 3; - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches); - } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("FoldScaleAxis")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func); - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); - } - if (cfg.pass_enabled("CanonicalizeOps")) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func); + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. 
+ if (targets.size() == 1) { + pass_seqs.push_back(transform::AlterOpLayout()); } - if (cfg.pass_enabled("AlterOpLayout")) { - if (targets.size() == 1) { - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - auto enter_pf = GetPackedFunc("_EnterTargetScope"); - auto exit_pf = GetPackedFunc("_ExitTargetScope"); - for (const auto& kv : targets) { - auto target = Target::create(kv.second); - (*enter_pf)(target); - func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - (*exit_pf)(); - } - } else { - LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" - << " execution yet."; + pass_seqs.push_back(transform::FoldConstant()); + + // Create a sequential pass and perform optimizations. + transform::Pass seq = transform::Sequential(pass_seqs); + if (targets.size() == 1) { + for (const auto& kv : targets) { + With tctx(kv.second); + relay_module = seq(relay_module); } + } else { + relay_module = seq(relay_module); } - if (cfg.pass_enabled("FoldConstant")) { - func = CallPackedFunc("relay._ir_pass.FoldConstant", func); + + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + if (targets_.size() > 1) { + relay_module = + RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); } - return func; + + // Fuse the operations if it is needed. + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + + return relay_module; } + + /*! + * \brief Create a default type. + * \param device_type The device type index. + * \return the default target for the device. + */ + Target CreateDefaultTarget(int device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") return Target::Create("llvm"); + if (name == "gpu") return Target::Create("cuda"); + return Target::Create(name); + } + /*! * \brief Update the target and fallback device required for heterogeneous * compilation. 
CPU is used as the fallback device if it wasn't provided. * Meanwhile, a CPU device type and "llvm" pair will be added to the target * dictionary in this case. * - * \param targets dictionary - * \param cfg - * \return Map + * \param fallback_device The fallback device for heterogeneous execution. */ - Map UpdateHeterogeneousInputs( - const std::unordered_map& targets, - const RelayBuildConfig& cfg) { - Map device_target; - std::unordered_map tmp_map; - auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); - - for (const auto& kv : targets) { - tmp_map[ContextMap::Str2Mask(kv.first)] = kv.second; + void UpdateHeterogeneousInputs(int fallback_device) { + std::unordered_map tmp_map; + for (const auto& kv : targets_) { + tmp_map[kv.first->value] = kv.second; } - if (tmp_map.count(fallback_idx) == 0) { - tmp_map[fallback_idx] = cfg.fallback_device; + if (tmp_map.count(fallback_device) == 0) { + targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); } - for (const auto& kv : tmp_map) { - device_target.Set( - ir::IntImm::make(HalideIR::Int(64), kv.first), - ir::StringImm::make(kv.second)); - } - return device_target; } + /*! * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func - * \param cfg - * \param targets_map_ptr - * \return Function + * \param relay_module The input Relay module. + * \param fallback_device The fallback device for heterogeneous execution. + * + * \return updated_module The updated module after device annotation. 
*/ - Function RunDeviceAnnotationPass( - Function func, - const RelayBuildConfig& cfg, - Map* targets_map_ptr) { - auto fallback_idx = ContextMap::Str2Mask(cfg.fallback_device); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, fallback_idx); - auto device_map = CallPackedFunc >("relay._ir_pass.CollectDeviceInfo", - func, - nullptr); - if (device_map.size() == 0) { - auto annotation_map = - CallPackedFunc >("relay._ir_pass.CollectDeviceAnnotationOps", - func, - nullptr); - if (annotation_map.size() == 0) { - targets_map_ptr->Set( - ir::IntImm::make(HalideIR::Int(64), 0), - ir::StringImm::make(cfg.fallback_device)); + relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, + int fallback_device) { + UpdateHeterogeneousInputs(fallback_device); + auto rewrite = transform::RewriteAnnotatedOps(fallback_device); + auto updated_module = rewrite(relay_module); + CHECK(updated_module.defined()); + + tvm::Map device_map; + for (const auto& it : updated_module->functions) { + device_map = relay::CollectDeviceInfo(it.second); + if (!device_map.empty()) break; + } + + if (device_map.empty()) { + tvm::Map annotation_map; + for (const auto& it : relay_module->functions) { + annotation_map = relay::CollectDeviceAnnotationOps(it.second); + if (!annotation_map.empty()) break; + } + // None op is annotated but they are fallen back to the default device. + if (annotation_map.empty()) { + targets_.Set(0, CreateDefaultTarget(fallback_device)); } else { + // All ops are annotated to the same device type. int64_t dev_type = -1; for (auto kv : annotation_map) { dev_type = kv.second->value; @@ -594,70 +409,54 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. 
Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set( - ir::IntImm::make(HalideIR::Int(64), 0), - ir::StringImm::make(ContextMap::Mask2Str(dev_type))); + targets_.Set(0, CreateDefaultTarget(dev_type)); } } - return func; + return updated_module; } /*! * \brief Build relay function to runtime module * * \param func Relay Function - * \param cfg Relay build config * \param params parameters */ - void BuildRelay(Function func, - const RelayBuildConfig& cfg, - const std::unordered_map ¶ms) { - // convert - tvm_cfg_ = build_config(); - Map device_target; - if (targets_.size() > 1) { - device_target = UpdateHeterogeneousInputs(targets_, cfg); - } else { - for (auto &kv : targets_) { - device_target.Set( - ir::IntImm::make(HalideIR::Int(64), ContextMap::Str2Mask(kv.first)), - ir::StringImm::make(kv.second)); - } - } - func = Optimize(func, targets_, cfg, params); - if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, cfg, &device_target); + void BuildRelay( + Function func, + const std::unordered_map& params) { + if (params.size()) { + func = BindParamsByName(func, params); } - // TODO(@jroesch): use the passes directly. - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr); - func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); + // Perform Module->Module optimizations. + relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay_module = Optimize(relay_module, targets_, params); + CHECK(relay_module.defined()); + // Get the updated function. + func = relay_module->Lookup(relay_module->entry_func->name_hint); + + // Generate code for the updated function. 
graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Init(nullptr, targets_); graph_codegen_->Codegen(func); ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - auto target_host = Target::create(target_host_); - ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host, tvm_cfg_); + ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, + BuildConfig::Current()); } protected: std::unique_ptr graph_codegen_; /*! \brief target device */ - std::unordered_map targets_; + TargetsMap targets_; /*! \brief target host device */ - std::string target_host_; - /*! \brief frontend optimization configure */ - RelayBuildConfig cfg_; + tvm::Target target_host_; /*! \brief parameters */ std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; - /*! \brief tvm building cfg */ - BuildConfig tvm_cfg_; }; runtime::Module RelayBuildCreate() { @@ -665,7 +464,8 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") +.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); }); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index a824c457107a..f11dd2875b80 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_[key] = value; } // Enforce use the target. - TargetContext target_ctx(key->target); + With target_scope(key->target); CHECK(!value->cached_func.defined()); auto spair = CreateSchedule(key->source_func, key->target); @@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode { cache_node->funcs = (*f)( spair.first, all_args, cache_node->func_name, key->source_func); } else { - tvm::BuildConfig bcfg = tvm::build_config(); + tvm::BuildConfig bcfg = BuildConfig::Create(); std::unordered_map binds; cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 415e0ec9c2a5..b14448c59166 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -51,7 +51,7 @@ using GraphAttrs = std::unordered_map; using GraphNodePtr = std::shared_ptr; using GraphInputNodePtr = std::shared_ptr; using GraphOpNodePtr = std::shared_ptr; -using TargetsMap = std::unordered_map; +using TargetsMap = std::unordered_map; /*! 
\brief Lowered outputs */ struct LoweredOutput { @@ -193,12 +193,10 @@ class GraphOpNode : public GraphNode { class GraphRuntimeCodegen : public ::tvm::relay::ExprFunctor(const Expr&)> { public: - GraphRuntimeCodegen(runtime::Module* mod, - const std::unordered_map& targets) : mod_(mod) { + GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) + : mod_(mod) { compile_engine_ = CompileEngine::Global(); - for (auto &kv : targets) { - targets_[kv.first] = Target::create(kv.second); - } + targets_ = targets; } LoweredOutput Codegen(relay::Function func) { @@ -406,7 +404,7 @@ class GraphRuntimeCodegen auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); auto &device_type = storage_device_map_[expr][1]; - auto call_dev_type = device_type[0]->value; //-> int to string + auto call_dev_type = device_type[0]->value; Target target; if (targets_.size() == 1) { // homogeneous execution. @@ -415,22 +413,17 @@ class GraphRuntimeCodegen } } else { // heterogeneous execution. 
- const auto call_dev_key = std::to_string(call_dev_type); std::string call_dev_name; if (call_dev_type == 0) { call_dev_name = "llvm"; } else { call_dev_name = runtime::DeviceName(call_dev_type); } - if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) { + if (targets_.count(call_dev_type) == 0) { LOG(FATAL) << "No target is provided for device " << call_dev_name; } - if (targets_.count(call_dev_key)) { - target = targets_[call_dev_key]; - } else { - target = targets_[call_dev_name]; - } + target = targets_[call_dev_type]; } CCacheKey key = (*pf0)(func, target); CachedFunc lowerd_func = (*pf1)(compile_engine_, key); @@ -604,30 +597,21 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { virtual PackedFunc GetFunction(const std::string& name, const std::shared_ptr& sptr_to_self) { if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; - void* mod = args[0]; - auto& sptr = args[1].node_sptr(); - auto* node = static_cast(sptr.get()); - auto& tmp_targets = node->data; - std::unordered_map targets; - for (size_t i = 0; i < tmp_targets.size(); i += 2) { - std::string key; - auto sk = Expr(tmp_targets[i]).as(); - auto ik = Expr(tmp_targets[i]).as(); - if (sk) { - key = sk->value; - } - if (ik) { - key = std::to_string(ik->value); - } - auto v = Expr(tmp_targets[i + 1]).as(); - targets[key] = v->value; - } - codegen_ = std::make_shared( - reinterpret_cast(mod), targets); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2) + << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + CHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = std::make_shared( + 
reinterpret_cast(mod), targets); + }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d700c2036e21..1cc81d5174a5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(Constructor constructor, - tvm::Array fields) { +ConstructorValue ConstructorValueNode::make(int tag, + tvm::Array fields, + Constructor constructor) { NodePtr n = make_node(); - n->constructor = constructor; + n->tag = tag; n->fields = fields; + n->constructor = constructor; return ConstructorValue(n); } @@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorValueNode* node, tvm::IRPrinter* p) { - p->stream << "ConstructorValueNode(" << node->constructor + p->stream << "ConstructorValueNode(" << node->tag << "," << node->fields << ")"; }); @@ -448,7 +450,7 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueNode::make(GetRef(con), args); + return ConstructorValueNode::make(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. 
Value fn_val = Eval(call->op); @@ -544,9 +546,8 @@ class Interpreter : const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); - CHECK_NE(cvn->constructor->tag, -1); - if (op->constructor->tag == cvn->constructor->tag) { - // todo(M.K.): should use ptr equality but it is broken + CHECK_NE(cvn->tag, -1); + if (op->constructor->tag == cvn->tag) { CHECK_EQ(op->patterns.size(), cvn->fields.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { if (!VisitPattern(op->patterns[i], cvn->fields[i])) { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 97f03c629cb7..dae90aad17d9 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -38,15 +38,22 @@ namespace tvm { namespace relay { + +namespace transform { + +Pass LambdaLift(); +Pass InlinePrimitives(); + +} // namespace transform + namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; +using namespace relay::transform; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); -Module LambdaLift(const Module& module); -Module InlinePrimitives(const Module& module); template using NodeMap = std::unordered_map; @@ -73,6 +80,8 @@ struct VMCompilerContext { ConstTensorShapeMap const_tensor_shape_map; // List of lowered functions std::vector lowered_funcs; + // The functions that have been lowered. + std::unordered_map seen_funcs; }; // Compute the constant pool, i.e a mapping from Constant node to constant index. 
@@ -94,13 +103,6 @@ struct ConstantPool : ExprVisitor { } } - void AddConstantTensorShape(TensorType expr, NDArray value) { - auto it = this->const_tensor_shape_map.find(expr); - if (it == this->const_tensor_shape_map.end()) { - this->const_tensor_shape_map.insert({expr, std::make_pair(index++, value)}); - } - } - void VisitExpr_(const ConstantNode* const_node) { auto konst = GetRef(const_node); auto it = this->const_map.find(konst); @@ -108,48 +110,6 @@ struct ConstantPool : ExprVisitor { this->const_map.insert({konst, index++}); } } - - NDArray GetTensorConstant(const TensorTypeNode* ttype) { - std::vector shapes; - for (auto sh : ttype->shape) { - shapes.push_back(Downcast(sh)->value); - } - int64_t s = shapes.size(); - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - auto shape_tensor = NDArray::Empty({s}, Type2TVMType(Int(64)), cpu_ctx); - int64_t* dims = static_cast(shape_tensor->data); - for (size_t i = 0; i < shapes.size(); ++i) { - dims[i] = shapes[i]; - } - return shape_tensor; - } - - void VisitExpr_(const CallNode* call_node) { - for (auto arg : call_node->args) { - this->VisitExpr(arg); - } - - Expr op = call_node->op; - auto func_node = op.as(); - if (func_node) { - auto ret_type = call_node->checked_type(); - if (const TensorTypeNode* ttype = ret_type.as()) { - auto shape = GetTensorConstant(ttype); - auto tensor_type = GetRef(ttype); - AddConstantTensorShape(tensor_type, shape); - } else if (const TupleTypeNode* ttype = ret_type.as()) { - for (size_t i = 0; i < ttype->fields.size(); ++i) { - auto f = ttype->fields[i]; - auto f_type = f.as(); - auto shape = GetTensorConstant(f_type); - auto tensor_type = GetRef(f_type); - AddConstantTensorShape(tensor_type, shape); - } - } - } - } }; std::tuple LayoutConstantPool(const Module& module) { @@ -177,9 +137,6 @@ struct VMCompiler : ExprFunctor { size_t registers_num; CompileEngine engine; - /*! \brief The functions that have been lowered. 
*/ - std::unordered_map seen_funcs; - /*! \brief Global shared meta data */ VMCompilerContext* context; @@ -200,6 +157,7 @@ struct VMCompiler : ExprFunctor { switch (instr.op) { case Opcode::AllocDatatype: case Opcode::AllocTensor: + case Opcode::AllocTensorReg: case Opcode::GetField: case Opcode::LoadConst: case Opcode::Select: @@ -254,13 +212,13 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const MatchNode* match_node) { auto match = GetRef(match_node); LOG(FATAL) << "translation of match nodes to the VM is" - << "currently unsupported" << std::endl; + << "currently unsupported"; } void VisitExpr_(const LetNode* let_node) { - DLOG(INFO) << let_node->value << std::endl; + DLOG(INFO) << let_node->value; this->VisitExpr(let_node->value); - DLOG(INFO) << this->last_register << std::endl; + DLOG(INFO) << this->last_register; var_register_map.insert({let_node->var, this->last_register}); this->VisitExpr(let_node->body); } @@ -273,7 +231,12 @@ struct VMCompiler : ExprFunctor { } void VisitExpr_(const GlobalVarNode* gvar) { - LOG(FATAL) << "Global variables should only appear in the call position"; + auto var = GetRef(gvar); + auto func = this->context->module->Lookup(var); + auto it = this->context->global_map.find(var); + CHECK(it != this->context->global_map.end()); + // Allocate closure with zero free vars + Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); } void VisitExpr_(const IfNode* if_node) { @@ -320,29 +283,51 @@ struct VMCompiler : ExprFunctor { } Instruction AllocTensorFromType(const TensorTypeNode* ttype) { - DataType dtype = ttype->dtype; - TVMType dltype = Type2TVMType(dtype); - + TVMType dltype = Type2TVMType(ttype->dtype); auto tensor_type = GetRef(ttype); - auto it = this->context->const_tensor_shape_map.find(tensor_type); - if (it == this->context->const_tensor_shape_map.end()) { - DLOG(INFO) << "Can not find constant shape for " << tensor_type; - } else { - Emit(Instruction::LoadConst(it->second.first, NewRegister())); + 
std::vector shape; + for (auto dim : tensor_type->shape) { + shape.push_back(Downcast(dim)->value); } - - return Instruction::AllocTensor(last_register, dltype, NewRegister()); + return Instruction::AllocTensor(shape, dltype, NewRegister()); } - void EmitInvokePrimitive(const Function& func, std::vector args_registers, + void EmitInvokePrimitive(const Function& func, + const std::vector& args_registers, const Type& ret_type) { + std::vector unpacked_arg_regs; std::vector allocs; - size_t return_num = 0; + + // Arity calculation must flatten tuples. + size_t arity = 0; + CHECK_EQ(func->params.size(), args_registers.size()); + for (size_t i = 0; i < func->params.size(); i++) { + auto ty = func->params[i]->checked_type(); + if (ty.as()) { + unpacked_arg_regs.push_back(args_registers[i]); + arity += 1; + } else if (auto tuple_ty = ty.as()) { + for (size_t f = 0; f < tuple_ty->fields.size(); f++) { + const auto& field = tuple_ty->fields[f]; + CHECK(field.as()) + << "only supports non-nested tuples currently " + << "found " << field; + auto dst = NewRegister(); + Emit(Instruction::GetField(args_registers[i], f, dst)); + unpacked_arg_regs.push_back(dst); + } + arity += tuple_ty->fields.size(); + } else { + LOG(FATAL) << "unsupported parameter type " << ty; + } + } + + size_t return_val_count = 0; if (const TensorTypeNode* ttype = ret_type.as()) { // Allocate space for the return tensor. 
auto alloc = AllocTensorFromType(ttype); allocs.push_back(alloc); - return_num = 1; + return_val_count = 1; } else if (const TupleTypeNode* ttype = ret_type.as()) { std::vector fields_registers; @@ -352,43 +337,42 @@ struct VMCompiler : ExprFunctor { allocs.push_back(AllocTensorFromType(f_type)); fields_registers.push_back(allocs.back().dst); } - return_num = ttype->fields.size(); + return_val_count = ttype->fields.size(); } else { LOG(FATAL) << "Unsupported return value type"; } + arity += return_val_count; for (auto& alloc : allocs) { Emit(alloc); - args_registers.push_back(alloc.dst); + unpacked_arg_regs.push_back(alloc.dst); } // Next generate the invoke instruction. CHECK(func->IsPrimitive()); - auto target = Target::create("llvm"); + auto target = Target::Create("llvm"); auto key = CCacheKeyNode::make(func, target); auto cfunc = engine->Lower(key); // TODO(jroesch): support lowered funcs for multiple targets CHECK_EQ(cfunc->funcs.size(), 1); auto op_index = -1; - if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) { + if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) { op_index = this->context->lowered_funcs.size(); this->context->lowered_funcs.push_back(cfunc->funcs[0]); - seen_funcs[cfunc->funcs[0]] = op_index; + this->context->seen_funcs[cfunc->funcs[0]] = op_index; } else { - op_index = seen_funcs[cfunc->funcs[0]]; + op_index = this->context->seen_funcs[cfunc->funcs[0]]; } - // If Tensor, 1 - // If Tuple, size of tuple - size_t arity = func->params.size() + return_num; - Emit(Instruction::InvokePacked(op_index, arity, return_num, args_registers)); - if (return_num > 1) { + Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); + + if (return_val_count > 1) { // return value is a tuple, we need to create a tuple std::vector fields_registers; - for (size_t i = func->params.size(); i < arity; ++i) { - fields_registers.push_back(args_registers[i]); + for (size_t i = arity - 
return_val_count; i < arity; ++i) { + fields_registers.push_back(unpacked_arg_regs[i]); } - Emit(Instruction::AllocDatatype(0, return_num, fields_registers, NewRegister())); + Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister())); } } @@ -396,7 +380,6 @@ struct VMCompiler : ExprFunctor { std::vector args_registers; for (auto arg : call_node->args) { - CHECK(arg.as()) << "found: " << AsText(arg, false) << std::endl << arg; this->VisitExpr(arg); args_registers.push_back(last_register); } @@ -416,18 +399,14 @@ struct VMCompiler : ExprFunctor { auto func = this->context->module->Lookup(global); if (IsClosure(func)) { auto arity = func->params.size(); - std::vector free_var_registers; - for (size_t i = 0; i < arity; ++i) { - free_var_registers.push_back(var_register_map.at(func->params[i])); - } - Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister())); + Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); } else { Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); } } else if (auto constructor_node = op.as()) { auto constructor = GetRef(constructor_node); - auto tag = GetConstructorTag(constructor); - Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister())); + Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, + NewRegister())); } else if (auto var_node = op.as()) { VisitExpr(GetRef(var_node)); Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister())); @@ -436,18 +415,6 @@ struct VMCompiler : ExprFunctor { } } - size_t GetConstructorTag(tvm::relay::Constructor constructor) { - auto it = this->context->tag_map.find(constructor); - if (it != this->context->tag_map.end()) { - return it->second; - } else { - auto tag = this->context->tag_map.size(); - this->context->tag_map[constructor] = tag; - this->context->tag_index_map[tag] = constructor; - return tag; - } - } - 
void VisitExpr_(const FunctionNode* func_node) { if (!func_node->IsPrimitive()) { LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl @@ -502,7 +469,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, runtime::Module mod; if (lowered_funcs.size() > 0) { // TODO(@jroesch): we need to read target from build config - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); if (const auto* f = runtime::Registry::Get("relay.backend.build")) { mod = (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target); } else { @@ -516,7 +483,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, } VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { - DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl; + DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false); size_t params = func->params.size(); VMCompiler compiler(context); compiler.Compile(func); @@ -534,10 +501,13 @@ VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const F } Module OptimizeModule(const Module& mod) { - ToANormalForm(mod->entry_func, mod); - InlinePrimitives(mod); - LambdaLift(mod); - return InlinePrimitives(mod); + transform::Sequential seq({transform::ToANormalForm(), + transform::InlinePrimitives(), + transform::LambdaLift(), + transform::InlinePrimitives()}); + auto pass_ctx = transform::PassContext::Create(); + tvm::With ctx(pass_ctx); + return seq(mod); } void PopulateGlobalMap(GlobalMap* global_map, const Module& mod) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index b033a37e42b8..1e561f8a8214 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -37,6 +37,21 @@ namespace tvm { namespace relay { 
namespace vm { +// TODO(@jroesch): write verifier + +/* This pass will eliminate primitives which have been lifted by the ANF + * transform inlining them directly into call sites. + * + * This makes VM related code generation easier as the call target is always + * a primitive function. + * + * let prim = fn(...) { ... }; + * prim(...) + * + * will become: + * + * (fn(...) { ... })(...) + */ struct PrimitiveInliner : ExprMutator { Module module_; std::unordered_map var_map; @@ -92,55 +107,46 @@ struct PrimitiveInliner : ExprMutator { } } - Function Inline(const Function& func) { - DLOG(INFO) << "Before inlining primitives: " << std::endl - << "func= " << AsText(func, false) << std::endl; - - auto inlined = FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); - - inlined = Downcast(DeadCodeElimination(inlined)); - - DLOG(INFO) << "After inlining primitives" << std::endl - << "after_func= " << AsText(inlined, false) << std::endl; - return inlined; + Module Inline() { + auto gvar_funcs = module_->functions; + for (auto pair : gvar_funcs) { + auto global = pair.first; + auto func = pair.second; + DLOG(INFO) << "Before inlining primitives: " << global + << std::endl << AsText(func, false); + + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(global, func, true); + + DLOG(INFO) << "After inlining primitives: " << global + << std::endl << AsText(func, false); + } + return module_; } }; -// TODO(@jroesch): write verifier - -/* This pass will eliminate primitives which have been lifted by the ANF - * transform inlining them directly into call sites. - * - * This makes VM related code generation easier as the call target is always - * a primitive function. - * - * let prim = fn(...) { ... }; - * prim(...) - * - * will become: - * - * (fn(...) { ... })(...) 
- */ -Module InlinePrimitives(const Module& module) { - PrimitiveInliner inliner(module); +} // namespace vm - tvm::Map updates; +namespace transform { - // There is an ordering bug here. - for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, inliner.Inline(func)); - } +Pass InlinePrimitives() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::PrimitiveInliner(m).Inline(); + }; + auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); + // Eliminate dead code for each function after inlining. + return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.InlinePrimitives") +.set_body_typed(InlinePrimitives); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 13d8112440fb..668c024a8d55 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -54,9 +55,14 @@ Function MarkClosure(const Function& func) { return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); } +/* The goal of this class is to lift out any nested functions into top-level + * functions. + * + * We will lift a function out into a global which takes the set of the free + * vars and then return the new created function. 
+ */ struct LambdaLifter : ExprMutator { Module module_; - std::vector> lifted_; explicit LambdaLifter(const Module& module) : module_(module) {} Expr VisitExpr_(const FunctionNode* func_node) final { @@ -71,8 +77,7 @@ struct LambdaLifter : ExprMutator { auto free_type_vars = FreeTypeVars(func, module_); auto body = Downcast(ExprMutator::VisitExpr_(func_node)); - // When performing this optimization there are two - // cases. + // When performing this optimization there are two cases. // // The first case in which we have no free variables // we can just lift the function into the global @@ -80,7 +85,7 @@ struct LambdaLifter : ExprMutator { // // // The second case requires that we generate a special - // function with makes a distinction between allocating + // function which makes a distinction between allocating // a closure, and then the code for the closure. // // We represent a closure allocation by lifting the @@ -92,7 +97,7 @@ struct LambdaLifter : ExprMutator { // function marked as a closure is used to emit allocation // code for the closure's environment. // - // The "inner" function is should be used to generate the + // The "inner" function should be used to generate the // code for the closure. Function lifted_func; if (free_vars.size() == 0) { @@ -107,16 +112,16 @@ struct LambdaLifter : ExprMutator { CHECK(lifted_func.defined()); auto name = GenerateName(lifted_func); - auto global = this->module_->GetGlobalVar(name); + auto global = GlobalVarNode::make(name); - lifted_.push_back({global, lifted_func}); + // Add the lifted function to the module. + module_->Add(global, lifted_func); if (free_vars.size() == 0) { return std::move(global); } else { - // If we need to allocate a closure - // we pass the variables in its environment - // here. + // If we need to allocate a closure, + // we pass the variables in its environment here. 
Array fvs; for (auto fv : free_vars) { fvs.push_back(fv); @@ -125,42 +130,39 @@ struct LambdaLifter : ExprMutator { } } - Function Lift(const Function& func) { - DLOG(INFO) << "Lifting: " << AsText(func, false) << std::endl; - return FunctionNode::make(func->params, VisitExpr(func->body), func->ret_type, - func->type_params, func->attrs); + Module Lift() { + // There is an ordering bug here. + auto glob_funcs = module_->functions; + for (auto pair : glob_funcs) { + auto func = pair.second; + DLOG(INFO) << "Lifting " << AsText(func, false); + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Add(pair.first, func, true); + } + return module_; } }; -/* The goal of this pass is to lift out any nested functions into top-level - * functions. - * - * We will lift the functions out into globals which take the set of the free vars - * and then return a function whcih has b - */ -Module LambdaLift(const Module& module) { - LambdaLifter lifter(module); - - tvm::Map updates; +} // namespace vm - // There is an ordering bug here. 
- for (auto pair : module->functions) { - auto global = pair.first; - auto func = pair.second; - updates.Set(global, lifter.Lift(func)); - } +namespace transform { - for (auto i = lifter.lifted_.begin(); i != lifter.lifted_.end(); i++) { - module->Add(i->first, i->second); - } +Pass LambdaLift() { + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return relay::vm::LambdaLifter(m).Lift(); + }; + return CreateModulePass(pass_func, 1, "LambdaLift", {}); +} - for (auto pair : updates) { - module->Add(pair.first, pair.second, true); - } +TVM_REGISTER_API("relay._transform.LambdaLift") +.set_body_typed(LambdaLift); - return module; -} +} // namespace transform -} // namespace vm } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc index 34d067b9c68c..cf0b952005fc 100644 --- a/src/relay/backend/vm/vm.cc +++ b/src/relay/backend/vm/vm.cc @@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector ctxs, return res; } -Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) { - CHECK(module.defined() && type.defined()); +Value VMToValue(const relay::Module& module, Object obj) { + CHECK(module.defined()); switch (obj->tag) { case ObjectTag::kTensor: { - CHECK(type.as()) << "VM internal error: return value must be a tensor"; return TensorValueNode::make(ToNDArray(obj)); } case ObjectTag::kDatatype: { - // const auto* tuple_type - // const auto& data_type = obj.AsDatatype(); + const auto& data_type = obj.AsDatatype(); - // tvm::Array fields; - // for (size_t i = 0; i < data_type->fields.size(); ++i) { - // fields.push_back(VMToValue(tag_index_map, data_type->fields[i])); - // } + tvm::Array fields; + for (size_t i = 0; i < data_type->fields.size(); ++i) { + fields.push_back(VMToValue(module, data_type->fields[i])); + } - // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields); - LOG(FATAL) << "fix me"; + return 
ConstructorValueNode::make(data_type->tag, fields); } default: LOG(FATAL) << "unsupported return value of type: " << obj->tag; @@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue LOG(FATAL) << "expected function or module"; } - auto return_type = module->Lookup(module->entry_func)->ret_type; - std::vector vm_args; for (auto i = 3; i < args.size(); i++) { Object obj = args[i]; @@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue auto result = EvaluateModule(module, {ctx}, vm_args); DLOG(INFO) << "Evaluate VM returning: result=" << result->tag; - *ret = VMToValue(module, return_type, result); + *ret = VMToValue(module, result); }); } // namespace vm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 64706933fde3..e0ec10a87061 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; + p->stream << "CallNode(" << node->op << ", " << node->args << ", " + << node->attrs << ", " << node->type_args << ")"; }); Let LetNode::make(Var var, Expr value, Expr body) { @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_API("relay._expr.TempExprRealize") .set_body_typed([](TempExpr temp) { - return temp->Realize(); + return temp->Realize(); }); } // namespace relay diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index c56c4ce17067..c57475476e58 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -219,6 +219,9 @@ class RelayHashHandler: size_t BindVar(const NodeRef& var) { size_t hash = std::hash()(var_counter++); CHECK_EQ(hash_map_.count(var), 0); + if (auto var_node = var.as()) { + hash = Combine(hash, 
TypeHash(var_node->type_annotation)); + } hash_map_[var] = hash; const auto* ty_param = var.as(); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 6b5fee82af89..58f614a3cc77 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map global_funcs, return Module(n); } -GlobalVar ModuleNode::GetGlobalVar(const std::string& name) { +GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); - if (it == global_var_map_.end()) { - auto gvar = GlobalVarNode::make(name); - global_var_map_.Set(name, gvar); - return gvar; - } else { - return (*it).second; - } + CHECK(it != global_var_map_.end()) + << "Cannot find global var " << name << " in the Module"; + return (*it).second; } void ModuleNode::AddUnchecked(const GlobalVar& var, @@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } -GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { +GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) << "Cannot find global type var " << name << " in the Module"; @@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) { gvar_node->data.erase(var->name_hint); } -Function ModuleNode::Lookup(const GlobalVar& var) { +Function ModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -Function ModuleNode::Lookup(const std::string& name) { +Function ModuleNode::Lookup(const std::string& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { +TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != 
type_definitions.end()) << "There is no definition of " << var->var->name_hint; return (*it).second; } -TypeData ModuleNode::LookupDef(const std::string& name) { +TypeData ModuleNode::LookupDef(const std::string& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupDef(id); } diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index 4a23f59f9637..b4303e7ac6b1 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -54,8 +54,8 @@ struct OpManager { std::vector frontend_funcs; // get singleton of the op manager static OpManager* Global() { - static OpManager inst; - return &inst; + static OpManager* inst = new OpManager(); + return inst; } }; diff --git a/src/relay/op/algorithm/sort.cc b/src/relay/op/algorithm/argsort.cc similarity index 94% rename from src/relay/op/algorithm/sort.cc rename to src/relay/op/algorithm/argsort.cc index 5777b79699b1..31aa88808a23 100644 --- a/src/relay/op/algorithm/sort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -18,9 +18,9 @@ */ /*! - * Copyright (c) 2018 by Contributors - * \file nms.cc - * \brief Non-maximum suppression operators + * Copyright (c) 2019 by Contributors + * \file argsort.cc + * \brief Argsort operators */ #include #include @@ -44,7 +44,6 @@ bool ArgsortRel(const Array& types, << types[0]; return false; } - CHECK_EQ(param->dtype, Float(32)); reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype)); return true; } @@ -74,5 +73,6 @@ input array along the given axis. 
.add_argument("data", "Tensor", "Input data.") .set_support_level(6) .add_type_rel("Argsort", ArgsortRel); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc new file mode 100644 index 000000000000..c88e2c3ea007 --- /dev/null +++ b/src/relay/op/algorithm/topk.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file topk.cc + * \brief TopK operators + */ +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(TopKAttrs); + +bool TopKRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + const TopKAttrs* param = attrs.as(); + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data); + int ndim = data->shape.size(); + int axis = param->axis; + if (axis < 0) { + axis += ndim; + } + CHECK(axis >= 0 && axis < ndim); + Array out_shape; + for (int i = 0; i < ndim; ++i) { + if (i != axis || param->k < 1) { + out_shape.push_back(data->shape[i]); + } else { + out_shape.push_back(param->k); + } + } + auto values_ty = TensorTypeNode::make(out_shape, data->dtype); + auto indices_ty = TensorTypeNode::make(out_shape, param->dtype); + if (param->ret_type == "both") { + reporter->Assign(types[1], TupleTypeNode::make({values_ty, indices_ty})); + } else if (param->ret_type == "values") { + reporter->Assign(types[1], values_ty); + } else if (param->ret_type == "indices") { + reporter->Assign(types[1], indices_ty); + } else { + LOG(FATAL) << "Unsupported ret type: " << param->ret_type; + } + return true; +} + +Expr MakeTopK(Expr data, + int k, + int axis, + std::string ret_type, + bool is_ascend, + DataType dtype) { + auto attrs = make_node(); + attrs->k = k; + attrs->axis = axis; + attrs->ret_type = ret_type; + attrs->is_ascend = is_ascend; + attrs->dtype = dtype; + static const Op& op = Op::Get("topk"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op._make.topk") +.set_body_typed(MakeTopK); + +RELAY_REGISTER_OP("topk") +.describe(R"doc(Get the top k elements in an input tensor along the given axis. 
+)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.TopKAttrs") +.add_argument("data", "Tensor", "Input data.") +.set_support_level(6) +.add_type_rel("TopK", TopKRel); + +} // namespace relay +} // namespace tvm + diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 4dd763b45654..44c9f89aa9e7 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -70,7 +70,8 @@ bool Pool2DRel(const Array& types, CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) return false; + const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index a4ebd1e8d050..647e4d0f4f90 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -355,6 +355,43 @@ Example:: .set_attr("TOpPattern", kCommReduce); +Array AllCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::all); +} + + +RELAY_REGISTER_REDUCE_OP("all") +.describe(R"code(Computes the logical AND of boolean array elements over given axes. 
+ +Example:: + + data = [[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]] + + all(data, axis=1) + [[False, True, False], + [False, False, False]] + + all(data, axis=0) + [[ True, False, False], + [ True, True, False], + [False, True, False]] + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.ReduceAttrs") +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", AllCompute) +.set_attr("TOpPattern", kCommReduce); + + Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 16d09c46dfa2..5b147a489b44 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -108,8 +108,8 @@ bool BroadcastRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); - DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] - << ",Out:" << types[2] << std::endl; + // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] + // << ",Out:" << types[2] << std::endl; if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); @@ -126,8 +126,8 @@ bool BroadcastCompRel(const Array& types, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); - DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] - << ",Out:" << types[2] << std::endl; + // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] + // << ",Out:" << types[2] << std::endl; if (auto t0 = ToTensorType(types[0])) { if (auto t1 = ToTensorType(types[1])) { CHECK_EQ(t0->dtype, t1->dtype); diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 2e5661cdc4dc..c0160e7d7128 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -50,9 +50,13 @@ bool GetValidCountRel(const Array& types, } Expr MakeGetValidCounts(Expr data, 
- double score_threshold) { + double score_threshold, + int id_index, + int score_index) { auto attrs = make_node(); attrs->score_threshold = score_threshold; + attrs->id_index = id_index; + attrs->score_index = score_index; static const Op& op = Op::Get("vision.get_valid_counts"); return CallNode::make(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index f51c201d0b2a..d623393049a6 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -338,17 +339,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call, // Limiations: // 1. the altered op should have the same number of arguments as the previous one // 2. do not support nested tuple arguments -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body([](TVMArgs args, TVMRetValue *ret) { +Expr AlterOpLayout(const Expr& expr) { TransformMemorizer transformMemorizer(make_node()); auto fcontext = [&](const Call& call) -> NodeRef{ return transformMemorizer; }; - *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext); -}); + return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); +} + +TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") +.set_body_typed(AlterOpLayout); } // namespace alter_op_layout +namespace transform { + +Pass AlterOpLayout() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; + return CreateFunctionPass(pass_func, 3, "AlterOpLayout", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.AlterOpLayout") +.set_body_typed(AlterOpLayout); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc new file mode 100644 index 000000000000..99f4a7f44e7e --- /dev/null +++ 
b/src/relay/pass/canonicalize_cast.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file canonicalize_cast.cc + * \brief Canonicalize cast expressions to make operator fusion more efficient. + */ +#include +#include +#include +#include +#include "pattern_util.h" +#include "pass_util.h" + +namespace tvm { +namespace relay { + +// This pass finds upcast that is referred by multiple elemwise/broadcast operators, and creates a +// copy of it in each branch such that after fusion the previous function have output with fewer +// bits. +// +// Consider the following example: +// \code +// def @main(x: int8) { +// %1 = cast(%x, f32) +// %2 = exp(%1) +// %3 = log(%1) +// (%3, 4) +// } +// \endcode +// +// We would like to prevent sharing of the cast expression such that operator fusion can produce +// more efficient result as below. 
+// \code +// def @main(x: int8) { +// %1 = fn (%p1: i8) { +// exp(cast(%p1, f32) +// } +// %3 = %1(%x) +// %2 = fn (%p1: i8) { +// log(cast(%p1, f32) +// } +// %4 = %2(%x) +// (%3, 4) +// } +// \endcode +class CastCanonicalizer : public ExprMutator { + public: + Expr VisitExpr_(const CallNode* call) { + static auto fpattern = Op::GetAttr("TOpPattern"); + + if (const OpNode* opnode = call->op.as()) { + auto pattern = fpattern[GetRef(opnode)]; + if (pattern <= kBroadcast) { + Array call_args = call->args; + bool unchanged = true; + for (size_t i = 0; i < call_args.size(); ++i) { + Expr arg = call_args[i]; + Expr new_arg = GetNewCallArg(arg); + if (!arg.same_as(new_arg)) { + call_args.Set(i, new_arg); + unchanged = false; + } + } + if (unchanged) { + return GetRef(call); + } + return CallNode::make(call->op, call_args, call->attrs, call->type_args); + } + } + + Expr new_expr = ExprMutator::VisitExpr_(call); + return new_expr; + } + + private: + std::unordered_map ref_counter_; + + Expr GetNewCallArg(const Expr& e) { + // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor + + static auto& cast = Op::Get("cast"); + Expr new_expr = this->VisitExpr(e); + + if (const CallNode* call = e.as()) { + if (call->op.same_as(cast)) { + auto attrs = call->attrs.as(); + const auto* from_type = call->args[0]->type_as(); + CHECK(from_type); + + if (from_type->dtype.bits() < attrs->dtype.bits()) { + if (++ref_counter_[call] > 1) { + const CallNode* new_call = new_expr.as(); + CHECK(new_call); + CHECK(new_call->op.same_as(cast)); + return CallNode::make(new_call->op, new_call->args, new_call->attrs, + new_call->type_args); + } + } + } + } + return new_expr; + } +}; + +Expr CanonicalizeCast(const Expr& e) { + return CastCanonicalizer().Mutate(e); +} + +namespace transform { + +Pass CanonicalizeCast() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CanonicalizeCast(f)); + }; + return 
CreateFunctionPass(pass_func, 3, "CanonicalizeCast", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CanonicalizeCast") +.set_body_typed(CanonicalizeCast); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index 9a4602750195..ff9e2304a3bc 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "pattern_util.h" namespace tvm { @@ -63,5 +64,21 @@ Expr CanonicalizeOps(const Expr& e) { TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") .set_body_typed(CanonicalizeOps); +namespace transform { + +Pass CanonicalizeOps() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CanonicalizeOps") +.set_body_typed(CanonicalizeOps); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 7e76322d5a2a..c95c1ddf8e16 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include "./expr_subst.h" @@ -357,5 +358,21 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body_typed(CombineParallelConv2D); +namespace transform { + +Pass CombineParallelConv2D(uint64_t min_num_branches) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", + 
{ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelConv2D") +.set_body_typed(CombineParallelConv2D); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 533c21429995..7e186f80df92 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -38,10 +38,10 @@ namespace relay { // calculate the dependency graph from expression class CalcDep : private ExprVisitor { public: - static Expr Eliminate(const Expr& e) { + static Expr Eliminate(const Expr& e, bool inline_once) { CalcDep cd; cd.Calculate(e); - Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); + Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_, inline_once); return el(e); } @@ -117,15 +117,23 @@ class CalcDep : private ExprVisitor { VarMap expr_map_; VarMap use_map_; VarSet letrec_set_; + bool inline_once_; explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, - const VarSet& letrec_set) : - expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } + const VarSet& letrec_set, + bool inline_once) : + expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set), inline_once_(inline_once) { } friend CalcDep; bool HasLet(const Var& v) { - // TODO(@jroesch): MK fix me - return (use_map_[v] > 0 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + switch (use_map_[v]) { + case 0: + return false; + case 1: + return letrec_set_.count(v) > 0 || !inline_once_; + default: + return true; + } } Expr VisitExpr_(const VarNode* op) final { @@ -144,12 +152,27 @@ class CalcDep : private ExprVisitor { }; }; -Expr DeadCodeElimination(const Expr& e) { - return CalcDep::Eliminate(e); +Expr DeadCodeElimination(const Expr& e, bool inline_once) { + return CalcDep::Eliminate(e, inline_once); } TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") .set_body_typed(DeadCodeElimination); +namespace transform { + +Pass DeadCodeElimination(bool 
inline_once) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(DeadCodeElimination(f, inline_once)); + }; + return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); +} + +TVM_REGISTER_API("relay._transform.DeadCodeElimination") +.set_body_typed(DeadCodeElimination); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 0139cc912849..8eeb493f1feb 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -67,6 +68,7 @@ class ValidateAnnotation : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { + ExprVisitor::VisitExpr_(call_node); if (IsOnDeviceNode(call_node)) { int device_type = GetDeviceId(call_node); if (annotation_map_.count(call_node)) { @@ -85,7 +87,14 @@ class ValidateAnnotation : private ExprVisitor { annotation_map_.insert({node, GetDeviceId(call_node)}); } } - ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const TupleGetItemNode* get_elem) final { + ExprVisitor::VisitExpr_(get_elem); + const auto* tn = get_elem->tuple.operator->(); + if (annotation_map_.count(tn)) { + annotation_map_.insert({get_elem, annotation_map_.at(tn)}); + } } /* @@ -176,7 +185,11 @@ class RewriteAnnotation : public ExprMutator { } Expr VisitExpr_(const CallNode* call_node) final { - if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) { + if (IsOnDeviceNode(call_node)) { + return this->VisitExpr(call_node->args[0]); + } + + if (IsDeviceCopyNode(call_node)) { return ExprMutator::VisitExpr_(call_node); } @@ -248,7 +261,9 @@ class RewriteAnnotation : public ExprMutator { if (src->is_type() || src->is_type()) { return annotation_map_.at(dst) != fallback_device_; } else { - return false; + // There shouldn't be any copy nodes between var/constant and 
another + // expression. + return !(src->is_type() || src->is_type()); } } else { return false; @@ -358,6 +373,9 @@ class DeviceInfo { public: void Visit(const Expr& expr) { if (const auto* fn = expr.as()) { + for (const auto& param : fn->params) { + this->VisitExpr(param); + } this->VisitExpr(fn->body); } else { this->VisitExpr(expr); @@ -402,7 +420,7 @@ class DeviceInfo { } void VisitExpr_(const VarNode* vn) final { - post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); + post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); } void VisitExpr_(const LetNode* ln) final { @@ -485,7 +503,52 @@ class DeviceInfo { Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { RewriteAnnotation rewrote = RewriteAnnotation(); - return rewrote.Rewrite(expr, fallback_device); + Expr new_expr = rewrote.Rewrite(expr, fallback_device); + + // Remove OnDevice operators. Note that these operators are only present at the + // leaves after annotation. Therefore, we can simply reconstruct the + // Function/Expr by removing them directly. 
+ if (const FunctionNode* fn = new_expr.as()) { + auto params = fn->params; + auto body = fn->body; + std::vector new_body; + if (const TupleNode* tuple = body.as()) { + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_body.push_back(field); + } + } + CHECK_GT(new_body.size(), 0U); + if (new_body.size() == 1) { + return FunctionNode::make(params, new_body[0], Type(nullptr), + fn->type_params, fn->attrs); + } else if (tuple->fields.size() == new_body.size()) { + return new_expr; + } else { + Tuple tuple_body = TupleNode::make(new_body); + return FunctionNode::make(params, tuple_body, Type(nullptr), + fn->type_params, fn->attrs); + } + } else { + return new_expr; + } + } else if (const TupleNode* tuple = new_expr.as()) { + std::vector new_fields; + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_fields.push_back(field); + } + } + CHECK_GT(new_fields.size(), 0U); + if (tuple->fields.size() == new_fields.size()) { + return new_fields.size() == 1 ? new_fields[0] : new_expr; + } else { + return new_fields.size() == 1 ? 
new_fields[0] + : TupleNode::make(new_fields); + } + } else { + return new_expr; + } } Map CollectDeviceInfo(const Expr& expr) { @@ -505,6 +568,21 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); +namespace transform { + +Pass RewriteAnnotatedOps(int fallback_device) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(RewriteAnnotatedOps(f, fallback_device)); + }; + return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation") +.set_body_typed(RewriteAnnotatedOps); + +} // namespace transform + } // namespace relay } // namespace tvm - diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index f8432f671855..883681adcaf4 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -29,6 +29,7 @@ */ #include #include +#include #include #include "./pattern_util.h" @@ -87,5 +88,21 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") .set_body_typed(EliminateCommonSubexpr); +namespace transform { + +Pass EliminateCommonSubexpr(PackedFunc fskip) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; + return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr") +.set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 0193b9afc62e..3139d41d6393 100644 --- a/src/relay/pass/eta_expand.cc +++ 
b/src/relay/pass/eta_expand.cc @@ -67,5 +67,20 @@ Expr EtaExpand(const Expr& e, const Module& mod) { TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand); +namespace transform { + +Pass EtaExpand() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(EtaExpand(f, m)); + }; + return CreateFunctionPass(pass_func, 1, "EtaExpand", {}); +} + +TVM_REGISTER_API("relay._transform.EtaExpand") +.set_body_typed(EtaExpand); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 45aa449e72ab..815407038b08 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -203,10 +204,10 @@ Expr FoldConstant(const Expr& expr) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. 
- BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return ConstantFolder(CreateInterpreter( Module(nullptr), ctx, target)).Mutate(expr); @@ -215,5 +216,20 @@ Expr FoldConstant(const Expr& expr) { TVM_REGISTER_API("relay._ir_pass.FoldConstant") .set_body_typed(FoldConstant); +namespace transform { + +Pass FoldConstant() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(FoldConstant(f)); + }; + return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); +} + +TVM_REGISTER_API("relay._transform.FoldConstant") +.set_body_typed(FoldConstant); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index c738e3e3b731..53089807ace5 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "pattern_util.h" #include "pass_util.h" @@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); -Expr ForwardFoldScaleAxis(Expr data) { +Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); auto fcontext = [&](const Call& call) -> NodeRef{ auto it = message.find(call.get()); @@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d") RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); -Expr BackwardFoldScaleAxis(Expr data) { +Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } @@ -950,5 +951,42 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") .set_body_typed(BackwardFoldScaleAxis); } // namespace fold_scale_axis + +namespace transform { + +Pass ForwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; 
+ return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass BackwardFoldScaleAxis() { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast( + relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; + return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", + {ir::StringImm::make("InferType")}); +} + +Pass FoldScaleAxis() { + // FoldScaleAxis pass contains the following three passes. Therefore, we can + // register it as a sequential pass. + Pass pass = Sequential( + {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); + return pass; +} + +TVM_REGISTER_API("relay._transform.FoldScaleAxis") +.set_body_typed(FoldScaleAxis); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 88a2d669da9f..8ad61270e33a 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr, return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } +namespace transform { + +using std::function; + +Pass ForwardRewrite(const std::string& rewrite_map_attr_name, + function fcontext, + function fmulti_ref_trigger) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_map_attr_name, + fcontext, + fmulti_ref_trigger)); + }; + return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {}); +} + +Pass ForwardRewrite(const FForwardRewrite& rewrite_func, + function fcontext, + function fmulti_ref_trigger) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast(ForwardRewrite(f, + rewrite_func, + fcontext, + fmulti_ref_trigger)); + }; + return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {}); +} + +} // namespace 
transform } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index d0d0cab22432..9f940e54953b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include "./pattern_util.h" #include "../../common/arena.h" @@ -964,5 +965,23 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { TVM_REGISTER_API("relay._ir_pass.FuseOps") .set_body_typed(FuseOps); + +namespace transform { + +Pass FuseOps(int fuse_opt_level) { + runtime::TypedPackedFunc pass_func = + [=](Function f, Module m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; + return Downcast(FuseOps(f, opt_level, m)); + }; + return CreateFunctionPass(pass_func, 1, "FuseOps", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.FuseOps") +.set_body_typed(FuseOps); + +} // namespace transform + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 5c5ea01ac2f3..91072b31a910 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -279,7 +279,7 @@ struct ReverseAD : ExprMutator { } std::vector orig_args; for (const auto& arg : args) { - orig_args.push_back(GetField(VisitExpr(arg), 0)); + orig_args.push_back(GetField(arg, 0)); } Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args); Var orig_var = ll->Push(orig); diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index c9ee4eec0337..3d77fabe6fe9 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,9 @@ #include #include #include +#include #include +#include "pattern_util.h" namespace tvm { namespace relay { @@ -65,7 +67,7 @@ int64_t ConvMacCount(const Call& call_node) { } Array args = call_node->args; CHECK(args.size() == 2) - << "The number of input arguments of a CONV 2D node should be 2."; + << 
"The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; @@ -73,18 +75,21 @@ int64_t ConvMacCount(const Call& call_node) { int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); CHECK(C_ind != -1) - << "There is no input channel dimension."; + << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; CHECK(kernel_size.size() == 2) - << "The dimension of the kernel size in Conv 2D should be 2."; + << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; - int64_t count = input_channel * GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + CHECK_EQ(input_channel % conv_2d_attr->groups, 0) + << "The number of input channels is not divisble by groups."; + count *= input_channel/conv_2d_attr->groups; return count; } diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc new file mode 100644 index 000000000000..173d6eacf528 --- /dev/null +++ b/src/relay/pass/match_exhaustion.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file match_exhaustion.cc + * \brief Checking Relay match expression exhaustiveness. + * + * This file implements a function that checks whether a match + * expression is exhaustive, that is, whether a given match clause + * matches every possible case. This is important for ensuring + * code correctness, since hitting an unmatched case results in a + * dynamic error unless exhaustiveness is checked in advance. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Possible pattern match results */ +enum MatchResult : int { + kMatch = 0, // pattern matches + kClash = 1, // pattern conflicts + kUnspecified = 2, // ambiguous: candidate needs more constructors specified +}; + +class CandidateChecker : public PatternFunctor { + public: + explicit CandidateChecker() {} + + MatchResult Check(const Pattern& pat, const Pattern& candidate) { + return this->VisitPattern(pat, candidate); + } + + // for a constructor pattern, we must ensure that the candidate is + // a ConstructorPattern, that it has the same constructor, and + // that its fields match the subpatterns. 
+ MatchResult VisitPattern_(const PatternConstructorNode* op, const Pattern& cand) override { + auto* ctor_cand = cand.as(); + // attempting to match non-constructor to constructor pattern: need to specify + if (ctor_cand == nullptr) { + return MatchResult::kUnspecified; + } + + // check that constructors match + if (!op->constructor.same_as(ctor_cand->constructor)) { + return MatchResult::kClash; + } + + // now check that subpatterns match + CHECK(op->patterns.size() == ctor_cand->patterns.size()); + bool unspecified = false; + for (size_t i = 0; i < op->patterns.size(); i++) { + MatchResult submatch = this->Check(op->patterns[i], ctor_cand->patterns[i]); + // if we have a clash anywhere, then we can return clash + if (submatch == MatchResult::kClash) { + return MatchResult::kClash; + } + if (submatch == MatchResult::kUnspecified) { + unspecified = true; + } + } + // only return unspecified if we have ruled out a clash + if (unspecified) { + return MatchResult::kUnspecified; + } + return MatchResult::kMatch; + } + + // wildcard and var patterns always match + MatchResult VisitPattern_(const PatternWildcardNode*, const Pattern&) override { + return MatchResult::kMatch; + } + + MatchResult VisitPattern_(const PatternVarNode*, const Pattern&) override { + return MatchResult::kMatch; + } +}; + +// Returns list of arrays corresponding to Cartesian product of input list +Array> CartesianProduct(Array> fields) { + CHECK_NE(fields.size(), 0); + Array field_vals = fields[fields.size() - 1]; + Array> ret; + + // base case: this is the last field left + if (fields.size() == 1) { + for (auto val : field_vals) { + ret.push_back(Array{val}); + } + return ret; + } + + // if we have more fields left, get the sub-candidates by getting + // their cartesian product and appending the elements here onto those + Array> remaining_fields; + for (size_t i = 0; i < fields.size() - 1; i++) { + remaining_fields.push_back(fields[i]); + } + Array> candidates = 
CartesianProduct(remaining_fields); + for (auto val : field_vals) { + for (auto candidate : candidates) { + candidate.push_back(val); + ret.push_back(candidate); + } + } + return ret; +} + +// Expands all wildcards in the candidate pattern once, using the pattern +// to decide which constructors to insert. Returns a list of all possible expansions. +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, + const Module& mod) { + auto ctor_cand = cand.as(); + PatternConstructor clause_ctor = Downcast(clause_pat); + auto gtv = Downcast(clause_ctor->constructor->belong_to); + + // for a wildcard node, create constructor nodes with wildcards for all args + if (!ctor_cand) { + TypeData td = mod->LookupDef(gtv); + // for each constructor add a candidate + Array ret; + for (auto constructor : td->constructors) { + Array args; + for (auto inp : constructor->inputs) { + args.push_back(PatternWildcardNode::make()); + } + ret.push_back(PatternConstructorNode::make(constructor, args)); + } + return ret; + } + + // for constructors, we will expand the wildcards in any field + // that is an ADT + Array> values_by_field; + for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { + auto* subpattern = clause_ctor->patterns[i].as(); + // for non-ADT fields, we can only have a wildcard for the value + if (!subpattern) { + values_by_field.push_back({PatternWildcardNode::make()}); + continue; + } + + // otherwise, recursively expand + values_by_field.push_back(ExpandWildcards(GetRef(subpattern), + ctor_cand->patterns[i], mod)); + } + + // generate new candidates using a cartesian product + auto all_subfields = CartesianProduct(values_by_field); + Array ret; + for (auto subfields : all_subfields) { + ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields)); + } + return ret; +} + +/*! + * \brief Finds cases that the match expression does not catch, if any. 
+ * \return Returns a list of cases that are not handled by the match + * expression. + */ +Array UnmatchedCases(const Match& match, const Module& mod) { + /* algorithm: + * candidates = { Wildcard } + * while candidates not empty { + * cand = candidates.pop() + * for clause in clauses { + * if clause fails: next clause + * if clause matches candidate: next candidate + * if candidate is not specific enough: + * candidates += expand_possible_wildcards(cand) + * next candidate + * } + * failed_candidates += { cand } + * } + * return failed_candidates + */ + std::stack candidates; + candidates.push(PatternWildcardNode::make()); + CandidateChecker checker; + + Array failures; + + while (!candidates.empty()) { + Pattern cand = candidates.top(); + candidates.pop(); + + bool failure = true; + for (auto clause : match->clauses) { + // if the check fails, we move on to the next + MatchResult check = checker.Check(clause->lhs, cand); + if (check == MatchResult::kClash) { + continue; + } + + // either success or we need to generate more candidates; + // either way, we're done with this candidate + failure = false; + if (check == MatchResult::kUnspecified) { + auto new_candidates = ExpandWildcards(clause->lhs, cand, mod); + for (auto candidate : new_candidates) { + candidates.push(candidate); + } + } + break; + } + + if (failure) { + failures.push_back(cand); + } + } + + return failures; +} + +// expose for testing only +TVM_REGISTER_API("relay._ir_pass.unmatched_cases") +.set_body_typed(const Match&, + const Module&)>([](const Match& match, + const Module& mod_ref) { + Module call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = ModuleNode::make({}, {}); + } + return UnmatchedCases(match, call_mod); + }); +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 5349532ca697..f1ca573d3e0e 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -74,28 +74,19 @@ * * The 
partial evaluator makes several assumptions, so there is room for improvement: * - * 0: The partial evaluator treats global variables as opaque. - * Doing PartialEval on a module level will solve this. - * - * 1: The partial evaluator assume all functions as terminating. - * We need to has a max_expand parameter that shrink on every compile time evaluation, - * to make sure PE does not infinite loop. - * Additionally, we might add a termination analysis pass that lift this requirement - * for function that analysis found terminating. - * - * 2: Every time an unknown effect happened, we clear the whole store. + * 0: Every time an unknown effect happened, we clear the whole store. * It is too conservative: if a local reference is created (and do not get passed outside), * An unknown global function call/global reference write can not modify it. * We can pair PE with escape analysis/alias analysis. * - * 3: We assume all unknown code has effect. Doing effect analysis can make the store more precise. + * 1: We assume all unknown code has effect. Doing effect analysis can make the store more precise. * - * 4: When doing pattern matching, we can simplify the match even for dynamic case. + * 2: When doing pattern matching, we can simplify the match even for dynamic case. * Right now it is all or nothing: either a complete match, or the original dynamic code. * Instead, we can get a match tree, pair it with the data and evaluate it to a normal form. * We then can reify the result. * - * 5: Every time a function is called, it's code will get expanded and partially evaluated. + * 3: Every time a function is called, its code will get expanded and partially evaluated. * We can do a binding time analysis to cache the result and avoid re-partial evaluation. * * These assumptions do not affect the correctness of the algorithm, however. 
@@ -104,11 +95,13 @@ #include #include #include +#include "../ir/type_functor.h" #include "pass_util.h" #include "let_list.h" namespace tvm { namespace relay { +namespace partial_eval { using namespace runtime; @@ -132,6 +125,8 @@ struct VarEqual { } }; +Expr PostProcess(const Expr&); + /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: @@ -150,10 +145,20 @@ class Static : public NodeRef { using ContainerType = StaticNode; }; +using Time = size_t; + struct PStaticNode : Node { + static Time time() { + static Time time_ = 0; + Time ret = time_; + time_++; + return ret; + } Static pstatic; // may be null Expr dynamic; - PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + Time created_time; + PStaticNode(const Static& pstatic, const Expr& dynamic) : + pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -341,6 +346,7 @@ class Store { }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { + CHECK(stat.defined()); return PStatic(make_node(stat, dynamic)); } @@ -375,23 +381,86 @@ DLContext CPUContext() { } FInterpreter CPUInterpreter() { - Target target = Target::create("llvm"); + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - BuildConfigContext fresh_build_ctx(build_config()); + With fresh_build_ctx(BuildConfig::Create()); return CreateInterpreter(Module(nullptr), CPUContext(), target); } +bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + +using FuncId = int; + +/*! + * \brief Annotate a function with a FuncId. 
+ */ +struct WithFuncIdAttrs : public tvm::AttrsNode { + FuncId fid; + + TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { + TVM_ATTR_FIELD(fid) + .describe("The FuncId that an function is annotated with.") + .set_default(-1); + } +}; + +TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); + +Op WithFuncIdOp() { + static const Op& op = Op::Get("annotation.with_funcid"); + return op; +} + +Expr MkWithFuncId(const Expr& expr, FuncId fid) { + auto attrs = make_node(); + attrs->fid = fid; + return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("annotation.with_funcid") +.describe(R"code(Annotate a function with a funcid.)code" +TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("func", "Function", "The input data."); + +Expr StripWithFuncId(const Expr& e); + +Expr DeDup(const Expr& e); + +Function AsFunc(const Expr& e) { + if (e.as()) { + return Downcast(e); + } else if (const CallNode* c = e.as()) { + CHECK(c->op.same_as(WithFuncIdOp())); + CHECK_EQ(c->args.size(), 1); + return AsFunc(c->args[0]); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } +} + class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars) { + PartialEvaluator(const tvm::Array& free_vars, + const Module& mod) : + mod_(mod) { for (const Var& v : free_vars) { env_.Insert(v, NoStatic(v)); } } + PStatic VisitExpr(const Expr& e, LetList* ll) final { + PStatic ret = ExprFunctor::VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -421,7 +490,20 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - return NoStatic(GetRef(op)); + GlobalVar gv = GetRef(op); + if (gv_map_.count(gv) == 0) { + if (mod_.defined()) { + Function func = mod_->Lookup(gv); + 
InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); + } else { + gv_map_.insert({gv, NoStatic(gv)}); + } + } + return gv_map_.at(gv); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { @@ -485,6 +567,10 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const CallNode* op, LetList* ll) final { + if (op->op.same_as(WithFuncIdOp())) { + CHECK_EQ(op->args.size(), 1); + return VisitExpr(op->args[0], ll); + } PStatic f = VisitExpr(op->op, ll); std::vector x; tvm::Array x_dyn; @@ -501,19 +587,40 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - Function func = GetRef(op); + struct TimeFrame { + PartialEvaluator* pe_; + FuncId fid_; + std::vector