diff --git a/CMakeLists.txt b/CMakeLists.txt index 6500ba013e28f..c23d403bcb6a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,11 @@ file(GLOB_RECURSE NNVM_COMPILER_SRCS file(GLOB TOPI_SRCS topi/src/*.cc ) -file(GLOB_RECURSE HALIDEIR_SRCS 3rdparty/HalideIR/src/*.cpp) +file(GLOB_RECURSE HALIDEIR_SRCS + 3rdparty/HalideIR/src/base/*.cpp + 3rdparty/HalideIR/src/ir/*.cpp + 3rdparty/HalideIR/src/tvm/*.cpp +) list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS}) file(GLOB RUNTIME_SRCS src/runtime/*.cc diff --git a/Jenkinsfile b/Jenkinsfile index 53645eb14b280..c38ec5296bf35 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -309,6 +309,24 @@ stage('Integration Test') { } } +/* +stage('Build packages') { + parallel 'conda CPU': { + node('CPU') { + sh "${docker_run} tvmai/conda-cpu ./conda/build_cpu.sh + } + }, + 'conda cuda': { + node('CPU') { + sh "${docker_run} tvmai/conda-cuda90 ./conda/build_cuda.sh + sh "${docker_run} tvmai/conda-cuda100 ./conda/build_cuda.sh + } + } + // Here we could upload the packages to anaconda for releases + // and/or the master branch +} +*/ + stage('Deploy') { node('doc') { ws('workspace/tvm/deploy-docs') { diff --git a/apps/android_rpc/README.md b/apps/android_rpc/README.md index 38725917f424f..1f2a46a8589c7 100644 --- a/apps/android_rpc/README.md +++ b/apps/android_rpc/README.md @@ -52,9 +52,25 @@ cd apps/android_rpc gradle clean build ``` -In `app/build/outputs/apk` you'll find `app-release-unsigned.apk`, use `dev_tools/gen_keystore.sh` to generate a signature and use `dev_tools/sign_apk.sh` to get the signed apk file `app/build/outputs/apk/tvmrpc-release.apk`. +In `app/build/outputs/apk` you'll find `app-release-unsigned.apk`, use `dev_tools/gen_keystore.sh` to generate a signature and use `dev_tools/sign_apk.sh` to get the signed apk file `app/build/outputs/apk/release/tvmrpc-release.apk`. -Upload `tvmrpc-release.apk` to your Android device and install it. +Upload `tvmrpc-release.apk` to your Android device and install it: + +```bash +$ANDROID_HOME/platform-tools/adb install app/build/outputs/apk/release/tvmrpc-release.apk +``` + +If you see error: + + adb: failed to install app/build/outputs/apk/release/tvmrpc-release.apk: + Failure [INSTALL_FAILED_UPDATE_INCOMPATIBLE: + Package ml.dmlc.tvm.tvmrpc signatures do not match the previously installed version; ignoring!] 
+ +Run uninstall first: + +```bash +$ANDROID_HOME/platform-tools/adb uninstall ml.dmlc.tvm.tvmrpc +``` ### Build with OpenCL diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index aef7629990c2f..548b69160b174 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -23,3 +23,7 @@ ifeq ($(USE_VULKAN), 1) APP_CPPFLAGS += -DTVM_VULKAN_RUNTIME=1 APP_LDFLAGS += -lvulkan endif + +ifeq ($(USE_SORT), 1) + APP_CPPFLAGS += -DUSE_SORT=1 +endif diff --git a/apps/android_rpc/app/src/main/jni/make/config.mk b/apps/android_rpc/app/src/main/jni/make/config.mk index c40ce4ba3ec7d..f61811bd604e4 100644 --- a/apps/android_rpc/app/src/main/jni/make/config.mk +++ b/apps/android_rpc/app/src/main/jni/make/config.mk @@ -22,6 +22,9 @@ USE_OPENCL = 0 # whether to enable Vulkan during compile USE_VULKAN = 0 +# whether to enable contrib sort functions during compile +USE_SORT = 1 + ifeq ($(USE_VULKAN), 1) # Statically linking vulkan requires API Level 24 or higher APP_PLATFORM = android-24 diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 60b41baaf8e70..aadc4d1884307 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -66,6 +66,10 @@ #include "../src/runtime/vulkan/vulkan_module.cc" #endif +#ifdef USE_SORT +#include "../src/contrib/sort/sort.cc" +#endif + #include diff --git a/conda/Dockerfile.template b/conda/Dockerfile.template index 59b9ac96814ee..1b5dc6fbef5e0 100644 --- a/conda/Dockerfile.template +++ b/conda/Dockerfile.template @@ -15,9 +15,13 @@ # specific language governing permissions and limitations # under the License. 
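With `USE_SORT = 1` now propagated through `Application.mk`, `config.mk` and `tvm_runtime.h` above, the contrib sort kernels are compiled into the Android runtime. A quick way to confirm this end to end is to query the function registry on the device over RPC; a minimal sketch, where the proxy address, port and key are placeholders that must match how the app was set up:

```python
from tvm import rpc

# placeholders: use the RPC proxy/tracker address, port and key the app connects with
remote = rpc.connect("0.0.0.0", 9090, key="android")

# get_function raises if the symbol is missing on the remote side, so this only
# succeeds when the device runtime was built with USE_SORT = 1 (sort.cc included)
argsort = remote.get_function("tvm.contrib.sort.argsort")
print("contrib argsort is available on the device")
```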
-FROM nvidia/cuda:{{ cuda_version }}-devel-centos6 +FROM nvidia/cuda:{{ cuda_version }}-devel-ubuntu16.04 -RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v{{ cudnn_short_version }}/cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -O && \ tar --no-same-owner -xzf cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz -C /usr/local && \ rm cudnn-{{ cuda_version }}-linux-x64-v{{ cudnn_version }}.tgz && \ ldconfig @@ -27,13 +31,16 @@ RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-lat chmod +x ~/miniconda.sh && \ ~/miniconda.sh -b -p /opt/conda && \ rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ /opt/conda/bin/conda install conda-build conda-verify && \ /opt/conda/bin/conda clean -ya +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + ENV PATH /opt/conda/bin:$PATH ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp WORKDIR /workspace RUN chmod -R a+w /workspace - -CMD conda build --output-folder /workspace/conda/pkg --variants '{cuda: True, cuda_version: {{ cuda_version }}}' /workspace/conda/tvm-libs diff --git a/conda/Makefile b/conda/Makefile deleted file mode 100644 index cda546ac73ce3..0000000000000 --- a/conda/Makefile +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -packages: - conda build tvm-libs - conda build tvm - conda build topi - conda built nnvm diff --git a/conda/nnvm/build.sh b/conda/build_cpu.sh old mode 100644 new mode 100755 similarity index 68% rename from conda/nnvm/build.sh rename to conda/build_cpu.sh index bdd333f57734c..992b1a369b96b --- a/conda/nnvm/build.sh +++ b/conda/build_cpu.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/sh # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,6 +17,15 @@ # under the License. 
set -e +set -u -cd nnvm/python -$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +# This is a fix for a weird bug in conda that makes it think +# it can't write in /tmp +HOME=/tmp +mkdir -p /tmp/.conda/pkgs +touch /tmp/.conda/pkgs/urls.txt +touch /tmp/.conda/environments.txt + + +conda build --output-folder=conda/pkg -c numba conda/tvm-libs +conda build --output-folder=conda/pkg -m conda/conda_build_config.yaml conda/tvm diff --git a/conda/topi/build.sh b/conda/build_cuda.sh old mode 100644 new mode 100755 similarity index 70% rename from conda/topi/build.sh rename to conda/build_cuda.sh index 4e5aafb937660..2c9a20ae66aec --- a/conda/topi/build.sh +++ b/conda/build_cuda.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/sh # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,6 +17,14 @@ # under the License. set -e +set -u -cd topi/python -$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +# This is a fix for a weird bug in conda that makes it think +# it can't write in /tmp +HOME=/tmp +mkdir -p /tmp/.conda/pkgs +touch /tmp/.conda/pkgs/urls.txt +touch /tmp/.conda/environments.txt + + +conda build --output-folder=conda/pkg --variants "{cuda: True, cuda_version: ${CUDA_VERSION%.*}}" -c numba conda/tvm-libs diff --git a/conda/topi/meta.yaml b/conda/cross-linux.cmake similarity index 54% rename from conda/topi/meta.yaml rename to conda/cross-linux.cmake index f4bc8950d4c49..360400267ae07 100644 --- a/conda/topi/meta.yaml +++ b/conda/cross-linux.cmake @@ -15,37 +15,24 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +# this one is important +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_PLATFORM Linux) +#this one not so much +set(CMAKE_SYSTEM_VERSION 1) -package: - name: topi - version: {{ version }} +# specify the cross compiler +set(CMAKE_C_COMPILER $ENV{CC}) -source: - path: ../.. 
+# where is the target environment +set(CMAKE_FIND_ROOT_PATH $ENV{PREFIX} $ENV{BUILD_PREFIX}/$ENV{HOST}/sysroot) -build: - number: 1 +# search for programs in the build host directories +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) -requirements: - host: - - python {{ python }} - - numpy - - setuptools - - decorator - - tvm-libs =={{ version }} - run: - - python - - {{ pin_compatible('numpy') }} - - decorator - - tvm-libs =={{ version }} - - tvm =={{ version }} +# for libraries and headers in the target directories +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) -test: - imports: - - topi - -about: - home: https://github.com/dmlc/tvm - license: Apache2 - summary: "TOPI: TVM Operator Inventory" +# god-awful hack because it seems to not run correct tests to determine this: +set(__CHAR_UNSIGNED___EXITCODE 1) diff --git a/conda/nnvm/meta.yaml b/conda/nnvm/meta.yaml deleted file mode 100644 index d948484a61e5f..0000000000000 --- a/conda/nnvm/meta.yaml +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -{% set version = "0.6.dev" %} - -package: - name: nnvm - version: {{ version }} - -source: - path: ../.. - -build: - number: 1 - skip: True # [win] - -requirements: - build: - - {{ compiler('cxx') }} - host: - - python {{ python }} - - cython - - numpy - - setuptools - - decorator - - tvm-libs =={{ version }} - run: - - tvm =={{ version }} - - topi =={{ version }} - - tvm-libs =={{ version }} - - python - - {{ pin_compatible('numpy') }} - - decorator - -test: - imports: - - nnvm - -about: - home: https://github.com/dmlc/nnvm - license: Apache2 - summary: Bring deep learning to bare metal diff --git a/conda/build_cuda.py b/conda/render_cuda.py similarity index 74% rename from conda/build_cuda.py rename to conda/render_cuda.py index 47af6ce4564e9..8057892fd83c1 100644 --- a/conda/build_cuda.py +++ b/conda/render_cuda.py @@ -29,8 +29,8 @@ # and from conda. 
# These two must be in sync -CUDNN_FULL_VERSION = '7.3.1.20' -CUDNN_VERSION = '7.3.1' +CUDNN_FULL_VERSION = '7.6.0.64' +CUDNN_VERSION = '7.6.0' condadir = os.path.dirname(sys.argv[0]) @@ -47,30 +47,15 @@ def render_dockerfile(version): cudnn_short_version=CUDNN_VERSION, cudnn_version=CUDNN_FULL_VERSION) fname = os.path.join(condadir, - 'Dockerfile.cuda' + version.replace('.', '')) + '../docker/Dockerfile.conda_cuda' + version.replace('.', '')) with open(fname, 'w') as f: f.write(txt) return fname -def build_docker(version): - vv = version.replace('.', '') - fname = render_dockerfile(version) - tagname = f'tvm-cuda{ vv }-forge' - subprocess.run(['docker', 'build', '-t', tagname, - condadir, '-f', fname], check=True) - return tagname - - -def build_pkg(version): - tagname = build_docker(version) - subprocess.run(['docker', 'run', '--rm', '-v', f'{ srcdir }:/workspace', - tagname], check=True) - - if __name__ == '__main__': build_versions = CUDA_VERSIONS if len(sys.argv) > 1: build_versions = sys.argv[1:] for version in build_versions: - build_pkg(version) + render_dockerfile(version) diff --git a/conda/tvm-libs/build.sh b/conda/tvm-libs/build.sh index e0b85910475ea..94919c60e7797 100644 --- a/conda/tvm-libs/build.sh +++ b/conda/tvm-libs/build.sh @@ -17,24 +17,37 @@ # under the License. set -e - -if [ "$cuda" == "True" ]; then - CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" -else - CUDA_OPT="" -fi +set -u if [ "$target_platform" == "osx-64" ]; then # macOS 64 bits - METAL_OPT="" # Conda can only target 10.9 for now + METAL_OPT="-DUSE_METAL=ON" + TOOLCHAIN_OPT="-DCMAKE_OSX_DEPLOYMENT_TARGET=10.11" else METAL_OPT="" + if [ "$target_platform" == "linux-64" ]; then + # Linux 64 bits + TOOLCHAIN_OPT="-DCMAKE_TOOLCHAIN_FILE=${RECIPE_DIR}/../cross-linux.cmake" + else + # Windows (or 32 bits, which we don't support) + TOOLCHAIN_OPT="" + fi +fi + +# When cuda is not set, we default to False +cuda=${cuda:-False} + +if [ "$cuda" == "True" ]; then + CUDA_OPT="-DUSE_CUDA=ON -DUSE_CUBLAS=ON -DUSE_CUDNN=ON" + TOOLCHAIN_OPT="" +else + CUDA_OPT="" fi rm -rf build || true mkdir -p build cd build -cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" .. +cmake $METAL_OPT $CUDA_OPT -DUSE_LLVM=$PREFIX/bin/llvm-config -DINSTALL_DEV=ON -DCMAKE_INSTALL_PREFIX="$PREFIX" $TOOLCHAIN_OPT .. make -j${CPU_COUNT} VERBOSE=1 make install cd .. diff --git a/conda/tvm-libs/meta.yaml b/conda/tvm-libs/meta.yaml index aad8f251c2a69..e3422a2174efe 100644 --- a/conda/tvm-libs/meta.yaml +++ b/conda/tvm-libs/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +{% set version = "0.6.dev1" %} package: name: tvm-libs @@ -25,21 +25,22 @@ source: path: ../.. 
build: - number: 1 - string: cuda{{ cuda_version }}_{{ PKG_BUILDNUM }} # [cuda] + number: 0 + string: cuda{{ cuda_version | replace('.', '') }}h{{ PKG_HASH }}_{{ PKG_BUILDNUM }} # [cuda] requirements: build: - # The OS X build will require some manual setup or it will break - # See https://docs.conda.io/projects/conda-build/en/latest/source/resources/compiler-tools.html#macos-sdk - - {{ compiler('cxx') }} - host: + # The anaconda compilers for OS X are old an annoying + # so we rely on the platform ones for now + - {{ compiler('cxx') }} # [linux] - cmake - - llvmdev ==6.0.0 + - make + host: + - llvmdev ==8.0.0 - zlib # [linux] run: - {{ pin_compatible('cudatoolkit', lower_bound=cuda_version, max_pin='x.x') }} # [cuda] - - {{ pin_compatible('cudnn', lower_bound='7.3.1', max_pin='x') }} # [cuda] + - {{ pin_compatible('cudnn', lower_bound='7.6.0', max_pin='x') }} # [cuda] about: home: https://github.com/dmlc/tvm diff --git a/conda/tvm/build.sh b/conda/tvm/build.sh index 6626aa5920914..494f90f0afa01 100644 --- a/conda/tvm/build.sh +++ b/conda/tvm/build.sh @@ -17,6 +17,16 @@ # under the License. set -e +set -u cd python $PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd .. + +cd topi/python +$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd ../.. + +cd nnvm/python +$PYTHON setup.py install --single-version-externally-managed --record=/tmp/record.txt +cd ../.. diff --git a/conda/tvm/meta.yaml b/conda/tvm/meta.yaml index 221dc7950f753..0daca4bcea2bd 100644 --- a/conda/tvm/meta.yaml +++ b/conda/tvm/meta.yaml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -{% set version = "0.6.dev" %} +{% set version = "0.6.dev1" %} package: name: tvm @@ -25,7 +25,7 @@ source: path: ../.. build: - number: 1 + number: 0 requirements: build: @@ -46,6 +46,15 @@ requirements: test: imports: - tvm + - topi + - nnvm + requires: + - nose + - scipy + source_files: + - tests/python + commands: + - python -m nose -v tests/python/integration about: home: https://github.com/dmlc/tvm diff --git a/docker/Dockerfile.conda_cpu b/docker/Dockerfile.conda_cpu new file mode 100644 index 0000000000000..0660b5daa0e26 --- /dev/null +++ b/docker/Dockerfile.conda_cpu @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +FROM ubuntu:16.04 + +RUN apt-get update && apt-get install -y bzip2 curl sudo binutils && rm -rf /var/lib/apt/lists/* + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +# Cache some of the packages for the builds +RUN /opt/conda/bin/conda install --download-only cmake make zlib && \ + /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 && \ + /opt/conda/bin/conda create -n py35 --download-only nose scipy numpy=1.11 cython decorator python=3.5 && \ + /opt/conda/bin/conda create -n py36 --download-only nose scipy numpy=1.11 cython decorator python=3.6 && \ + /opt/conda/bin/conda create -n py37 --download-only nose scipy numpy=1.11 cython decorator python=3.7 + +ENV PATH /opt/conda/bin:$PATH +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace diff --git a/docker/Dockerfile.conda_cuda100 b/docker/Dockerfile.conda_cuda100 new file mode 100644 index 0000000000000..d6e1cddbfd373 --- /dev/null +++ b/docker/Dockerfile.conda_cuda100 @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +FROM nvidia/cuda:10.0-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-10.0-linux-x64-v7.6.0.64.tgz -O && \ + tar --no-same-owner -xzf cudnn-10.0-linux-x64-v7.6.0.64.tgz -C /usr/local && \ + rm cudnn-10.0-linux-x64-v7.6.0.64.tgz && \ + ldconfig + + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + +ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace \ No newline at end of file diff --git a/docker/Dockerfile.conda_cuda90 b/docker/Dockerfile.conda_cuda90 new file mode 100644 index 0000000000000..f55aa1bf2e126 --- /dev/null +++ b/docker/Dockerfile.conda_cuda90 @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +FROM nvidia/cuda:9.0-devel-ubuntu16.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bzip2 curl sudo binutils && \ + rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL http://developer.download.nvidia.com/compute/redist/cudnn/v7.6.0/cudnn-9.0-linux-x64-v7.6.0.64.tgz -O && \ + tar --no-same-owner -xzf cudnn-9.0-linux-x64-v7.6.0.64.tgz -C /usr/local && \ + rm cudnn-9.0-linux-x64-v7.6.0.64.tgz && \ + ldconfig + + +RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda upgrade --all && \ + /opt/conda/bin/conda install conda-build conda-verify && \ + /opt/conda/bin/conda clean -ya + +RUN /opt/conda/bin/conda install --download-only cmake make zlib +RUN /opt/conda/bin/conda install --download-only -c numba llvmdev=8.0.0 + +ENV PATH /opt/conda/bin:$PATH +ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64 +ENV CONDA_BLD_PATH /tmp + +WORKDIR /workspace +RUN chmod -R a+w /workspace \ No newline at end of file diff --git a/docker/Dockerfile.demo_android b/docker/Dockerfile.demo_android index e9c3e4f6ce8eb..4f93e84007495 100644 --- a/docker/Dockerfile.demo_android +++ b/docker/Dockerfile.demo_android @@ -52,6 +52,8 @@ ENV PATH ${PATH}:${VULKAN_SDK}/bin ENV LD_LIBRARY_PATH ${LD_LIBRARY_PATH}:${VULKAN_SDK}/lib ENV VK_LAYER_PATH ${VULKAN_SDK}/etc/explicit_layer.d +RUN git clone https://github.com/KhronosGroup/OpenCL-Headers /usr/local/OpenCL-Headers/ + # Build TVM RUN cd /usr && \ git clone --depth=1 https://github.com/dmlc/tvm --recursive && \ @@ -69,3 +71,4 @@ RUN cd /usr && \ # Environment variables ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/nnvm/python/:/usr/tvm/vta/python:${PYTHONPATH} +ENV ANDROID_HOME=/opt/android-sdk-linux/ diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 38cf36f237c7b..88d07cc884399 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -22,4 +22,4 @@ set -o pipefail # install libraries for python package on ubuntu pip2 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typing antlr4-python2-runtime attrs -pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs +pip3 install nose pylint==1.9.4 six numpy nose-timer cython decorator scipy tornado typed_ast pytest mypy orderedset antlr4-python3-runtime attrs requests Pillow diff --git a/docs/api/python/relay/index.rst b/docs/api/python/relay/index.rst index 39a68b6d1f5d5..90746b8e5d4ee 100644 --- a/docs/api/python/relay/index.rst +++ b/docs/api/python/relay/index.rst @@ -33,7 +33,8 @@ 
compiler stack. expr frontend image - ir_pass + analysis + transform module nn op diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 92f7399a89a57..446c4c0c19a91 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -623,12 +623,15 @@ IntSet Intersect(const Array& sets); * give the domain of each variables. Return undefined IntSet to * represent failure. * + * \note The returned set may be smaller than set that + * contains all possible values of v that satisfies the bound. + * * \param v The target variable to be deduced. * \param cond The conditional expression. * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, - * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * The deduce bound must implies e for all value in relax_map + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const Map& hint_map, @@ -641,7 +644,7 @@ IntSet DeduceBound(Expr v, Expr cond, * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const std::unordered_map& hint_map, diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index ed4ac5ea6a63f..1233e9b0b89b8 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -36,10 +36,11 @@ namespace tvm { // Internal node container Buffer class BufferNode; -/*! \brief memory access kind */ -enum class AccessMask : int { - kRead = 1, - kWrite = 2 +/*! \brief buffer type */ +enum BufferType : int { + kDefault = 1, + // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + kAutoBroadcast = 2, }; /*! @@ -129,6 +130,8 @@ class BufferNode : public Node { * elem_offset is guaranteed to be multiple of offset_factor. */ int offset_factor; + /*! \brief buffer type */ + BufferType buffer_type; /*! \brief constructor */ BufferNode() {} @@ -142,6 +145,7 @@ class BufferNode : public Node { v->Visit("scope", &scope); v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); + v->Visit("buffer_type", &buffer_type); } /*! \return preferred index type for this buffer node */ @@ -159,7 +163,8 @@ class BufferNode : public Node { std::string name, std::string scope, int data_alignment, - int offset_factor); + int offset_factor, + BufferType buffer_type); static constexpr const char* _type_key = "Buffer"; TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index e1c92e50e6ad1..98dbf6bb62906 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -27,7 +27,6 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include #include #include #include diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 9e4e00ca47ed4..2a6507b62a33d 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode { /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). 
*/ - mutable int tag = -1; + mutable int32_t tag = -1; ConstructorNode() {} diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/analysis.h similarity index 55% rename from include/tvm/relay/pass.h rename to include/tvm/relay/analysis.h index 294d22b812a13..3672a22847dbf 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/analysis.h @@ -18,55 +18,21 @@ */ /*! - * \file tvm/relay/pass.h - * \brief The set of Relay passes written in C++. - */ -#ifndef TVM_RELAY_PASS_H_ -#define TVM_RELAY_PASS_H_ + * \file tvm/relay/analysis.h + * \brief The set of Relay analysis passes written in C++. + */ +#ifndef TVM_RELAY_ANALYSIS_H_ +#define TVM_RELAY_ANALYSIS_H_ -#include -#include +#include #include #include -#include #include -#include -#include -#include #include -#include namespace tvm { namespace relay { -/*! - * \brief Infer the type of an expression. - * - * The result of type checking is a new expression with unambigous - * type information filled in, as well as it's checked type field - * populated with the result type. - * - * \param expr The expression to type check. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return A type checked expression with its checked_type field populated. - */ -TVM_DLL Expr InferType(const Expr& expr, const Module& mod); - -/*! - * \brief Infer the type of a function as if it is mapped to var in the mod. - * - * \param f the function. - * \param mod The module used for referencing global functions. - * \param var The global variable corresponding to the function. - * - * \return A type checked Function with its checked_type field populated. - * \note this function mutates mod and is not thread-safe. - */ -TVM_DLL Function InferType(const Function& f, const Module& mod, - const GlobalVar& var); - /*! * \brief Check that types are well kinded by applying "kinding rules". * @@ -140,23 +106,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); */ TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); -/*! - * \brief Add abstraction over a function - * - * For example: `square` is transformed to - * `fun x -> square x`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion - * for more details. - * - * \param e The original function. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return the new function with abstraction - */ -TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); - /*! * \brief Check that each Var is only bound once. * @@ -288,87 +237,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); -/*! \brief Remove expressions which does not effect the program result. - * - * It will remove let bindings which are not referenced, - * and inline let bindings that are only used once. - * - * For example, this pass should turn `let a = 1 in 2` into `2`, - * as the value of the expression does not depend on a. - * - * As another example, `let a = 1 in a` will be optimized into 1, - * if the flag is turned on. - * - * \param e the expression to optimize. - * \param inline_once whether or not to inline binding used one. - * - * \return the optimized expression. - */ -TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); - -/*! - * \brief Fold constant expressions. - * - * \param expr the expression to be optimized. - * - * \return The optimized expression. - */ -TVM_DLL Expr FoldConstant(const Expr& expr); - -/*! 
- * \brief Fuse operations into expr into seperate functions. - * - * \param expr The expression. - * \param fuse_opt_level Optimization level. - * \param mod the module. - * - * \return The optimized expression. - */ -TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod); - -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param expr The expression. - * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite - * rule function. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * \return The rewritten expression. - */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param expr The expression. - * \param rewrite_func The rewrite func that will apply to all operators. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The rewritten expression. - */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - -/*! - * \brief Rewrite the annotated program. - * - * \param expr The expression. - * \param fallback_device The fallback device which is the default device for - * operators without annotation. - * - * \return The updated program. - */ -TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); - /*! * \brief Collect the device mapping information of each expression. * @@ -387,38 +255,6 @@ TVM_DLL Map CollectDeviceInfo(const Expr& expr); */ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); -/*! - * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). - * - * It will turn an expression that is in a graph form (with sharing implicit), - * to an expression with explicit sharing (A-Normal Form). - * - * The scope of the root expression is the global scope. - * - * The scope of any non root expression is the least common ancestor of all it's scope. - * - * Values are ordered by post-DFS order in each scope. - * - * \param e the expression to observably share. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return expression in A-Normal Form. - */ -TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); - -/*! - * \brief Remove let binding and directly share via pointer instead. - * - * It will remove all let binding, - * and turn all of the variable bound by let into direct pointer reference. - * - * \param e the expression. - * - * \return the expression in graph normal form. - */ -TVM_DLL Expr ToGraphNormalForm(const Expr& e); - /*! * \brief Finds cases that the given match expression does not catch, if any. * @@ -431,30 +267,6 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); */ TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); -/*! - * \brief Aggressive constant propagation/constant folding/inlining. - * It will do as much computation in compile time as possible. 
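The `ToANormalForm` documentation above is easiest to see on a tiny program: sharing that is implicit in the dataflow graph becomes an explicit `let`. A minimal sketch, assuming the pass is driven through `relay.transform` as elsewhere in this change:

```python
from tvm import relay
from tvm.relay import transform

x = relay.var("x", shape=(1,))
shared = relay.add(x, x)                         # used twice below; sharing is implicit
f = relay.Function([x], relay.multiply(shared, shared))

mod = relay.Module.from_expr(f)
mod = transform.ToANormalForm()(mod)
# the shared add is now bound once by a let and referenced twice
print(mod["main"])
```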
- * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). - * As a side effect, code size will explode. - * - * \param e the expression - * \param mod the module - * - * \return the optimized expression. - */ -TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); - -/* - * \brief Bind function parameters or free variables. - * - * Parameter binding can only happen if expr is a Function. - * binds cannot change internal arguments of internal functions. - * - * \param expr The function to be binded. - * \param binds The map of arguments to - */ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); - /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { /*! \brief Hash a Relay type. @@ -466,7 +278,6 @@ struct StructuralHash { * \return the hash value. */ size_t operator()(const Type& type) const; - /*! \brief Hash a Relay expression. * * Implements structural hashing of a Relay expression. @@ -478,20 +289,7 @@ struct StructuralHash { size_t operator()(const Expr& expr) const; }; -namespace vm { - -/*! - * \brief Compile a module, and construct the virtual machine. - * - * \param mod The module to compile. - * - * \return The constructed virtual machine. - */ -runtime::vm::VirtualMachine CompileModule(const Module& mod); - -} // namespace vm - } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_H_ +#endif // TVM_RELAY_ANALYSIS_H_ diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 68b7ccab99c7b..d05099f781acd 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - int tag; + int32_t tag; tvm::Array fields; @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int tag, + TVM_DLL static ConstructorValue make(int32_t tag, tvm::Array fields, Constructor construtor = {}); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 638f75968fd33..e888c54c17aca 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -55,7 +56,7 @@ struct Module; * The functional style allows users to construct custom * environments easily, for example each thread can store * a Module while auto-tuning. - * */ + */ class ModuleNode : public RelayNode { public: @@ -64,16 +65,12 @@ class ModuleNode : public RelayNode { /*! \brief A map from global type vars to ADT type data. */ tvm::Map type_definitions; - /*! \brief The entry function (i.e. "main"). */ - GlobalVar entry_func; - ModuleNode() {} void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("functions", &functions); v->Visit("type_definitions", &type_definitions); v->Visit("global_var_map_", &global_var_map_); - v->Visit("entry_func", &entry_func); v->Visit("global_type_var_map_", &global_type_var_map_); } @@ -118,6 +115,13 @@ class ModuleNode : public RelayNode { */ TVM_DLL void Remove(const GlobalVar& var); + /*! + * \brief Check if the global_var_map_ contains a global variable. + * \param name The variable name. + * \returns true if contains, otherise false. + */ + TVM_DLL bool ContainGlobalVar(const std::string& name) const; + /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. 
@@ -133,33 +137,40 @@ class ModuleNode : public RelayNode { TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; /*! - * \brief Lookup a global function by its variable. + * \brief Look up a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ TVM_DLL Function Lookup(const GlobalVar& var) const; /*! - * \brief Lookup a global function by its string name + * \brief Look up a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ TVM_DLL Function Lookup(const std::string& name) const; /*! - * \brief Lookup a global type definition by its variable. + * \brief Look up a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const; /*! - * \brief Lookup a global type definition by its name. + * \brief Look up a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const std::string& var) const; + /*! + * \brief Look up a constructor by its tag. + * \param tag The tag for the constructor. + * \return The constructor object. + */ + TVM_DLL Constructor LookupTag(const int32_t tag); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -172,10 +183,10 @@ class ModuleNode : public RelayNode { * Allows one to optionally pass a global function map as * well. * - * \param expr The expression to set as the entry point to the module. + * \param expr The expression to set as the main function to the module. * \param global_funcs The global function map. * - * \returns A module with expr set as the entry point. + * \returns A module with expr set as the main function. */ TVM_DLL static Module FromExpr( const Expr& expr, @@ -185,6 +196,9 @@ class ModuleNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); private: + /*! \brief Helper function for registering a typedef's constructors */ + void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); + /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -194,6 +208,11 @@ class ModuleNode : public RelayNode { * that ensures global uniqueness. */ tvm::Map global_type_var_map_; + + /*! \brief A map from constructor tags to constructor objects + * for convenient access + */ + std::unordered_map constructor_tag_map_; }; struct Module : public NodeRef { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 04b4e64dc9c3b..93129cf57a279 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -378,36 +378,6 @@ TVM_DLL Pass FoldConstant(); */ TVM_DLL Pass FuseOps(int fuse_opt_level = -1); -/*! - * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite - * rule function. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The pass. - */ -TVM_DLL Pass ForwardRewrite(const std::string& rewrite_map_attr_name, - std::function fcontext = nullptr, - std::function - fmulti_ref_trigger = nullptr); - -/*! 
- * \brief Apply rewrite rules to rewrite the expr in post DFS order. - * - * \param rewrite_func The rewrite func that will apply to all operators. - * \param fcontext Additional callback to provide context argument for each call node. - * \param fmulti_ref_trigger Transformation function to be called when - * an Expr consumed by multiple callers. - * - * \return The pass. - */ -TVM_DLL Pass ForwardRewrite(const FForwardRewrite& rewrite_func, - std::function fcontext = nullptr, - std::function fmulti_ref_trigger = nullptr); - /*! * \brief Rewrite the annotated program. * @@ -434,6 +404,22 @@ TVM_DLL Pass RewriteAnnotatedOps(int fallback_device); */ TVM_DLL Pass ToANormalForm(); +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \return the pass. + */ +TVM_DLL Pass ToCPS(); + /*! * \brief Remove let binding and directly share via pointer instead. * @@ -541,7 +527,132 @@ TVM_DLL Pass AlterOpLayout(); */ TVM_DLL Pass CanonicalizeCast(); +/*! + * \brief Add abstraction over a function + * + * For example: `square` is transformed to + * `fun x -> square x`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion + * for more details. + * + * \return The pass. + */ +TVM_DLL Pass EtaExpand(); + } // namespace transform + +/*! + * \brief Bind the free variables to a Relay expression. This is a helper + * function usually called by other pass functions to help optimizations. + * + * \param expr The input expression. + * \param binds The variable to expression map that will be used to help the + * binding. + * + * \return The updated expression. + */ +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); + +/*! + * \brief Infer the type of a function as if it is mapped to var in the mod. + * + * \param f the function. + * \param mod The module used for referencing global functions. + * \param var The global variable corresponding to the function. + * + * \return A type checked Function with its checked_type field populated. + * \note this function mutates mod and is not thread-safe. + */ +TVM_DLL Function InferType(const Function& f, + const Module& mod, + const GlobalVar& var); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. This + * function is used as a helper function to rewrtie an expression in a pass. + * + * \param expr The expression. + * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite + * rule function. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * \return The rewritten expression. + */ +TVM_DLL Expr ForwardRewrite(const Expr& expr, + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. This + * function is used as a helper function to rewrtie an expression in a pass. + * + * \param expr The expression. 
+ * \param rewrite_func The rewrite func that will apply to all operators. + * \param fcontext Additional callback to provide context argument for each call node. + * \param fmulti_ref_trigger Transformation function to be called when + * an Expr consumed by multiple callers. + * + * \return The rewritten expression. + */ +TVM_DLL Expr ForwardRewrite(const Expr& expr, + const FForwardRewrite& rewrite_func, + std::function fcontext = nullptr, + std::function fmulti_ref_trigger = nullptr); + +/*! + * \brief Rewrite the annotated program. + * + * \param expr The expression. + * \param fallback_device The fallback device which is the default device for + * operators without annotation. + * + * \return The updated program. + */ +TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device); + +/*! + * \brief Turn an expression into continuation passing style(CPS). + * + * CPS mean that every function will, instead of returning the result directly, + * be passed down an extra function (called the continuation) as argument, + * and pass the result to the continuation instead. + * + * Thus, every function call has to be passed an extra argument + * that represent the rest of the computation (Hence the name of continuation). + * + * Similarly, all other compute will be wrapped and call the continuation as well. + * + * \param f the function. + * \param mod the module. + * + * \return the converted Function. + */ +TVM_DLL Function ToCPS(const Function& f, const Module& mod); + +/*! + * \brief Remove the continuation argument of a CPS function. + * + * Note that this only transform the type back into un-CPS form + * when there is no higher order input/output. + * + * \param f the function. + * + * \return the converted Function. + */ +TVM_DLL Function UnCPS(const Function& f); + +/*! + * \brief Deduplicate the bound variables and type variables in the expression. + * + * \param e the expression. + * + * \return the deduplicated expression. 
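The continuation-passing style that the `ToCPS` and `UnCPS` comments above describe is easier to see outside Relay first. A plain-Python illustration of the same idea (not a TVM API):

```python
# direct style: the result is returned to the caller
def square(x):
    return x * x

# CPS: the caller also passes in "the rest of the computation" (the continuation k),
# and the result is handed to k instead of being returned
def square_cps(x, k):
    return k(x * x)

# computing square(3) + 1 in both styles
assert square(3) + 1 == 10
assert square_cps(3, lambda r: r + 1) == 10
```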
+ */ +TVM_DLL Expr DeDup(const Expr& e); + } // namespace relay } // namespace tvm diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 82b3dd4695415..17fd626ee51d0 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -962,10 +962,10 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } - if (GetCustomTypeRegistered(t.code)) { - os << "custom[" << GetCustomTypeName(t.code) << "]"; - } else { + if (t.code < kCustomBegin) { os << TypeCode2Str(t.code); + } else { + os << "custom[" << GetCustomTypeName(t.code) << "]"; } if (t.code == kHandle) return os; os << static_cast(t.bits); @@ -987,10 +987,10 @@ inline std::string TVMType2String(TVMType t) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { return "bool"; } - if (GetCustomTypeRegistered(t.code)) { - repr += "custom[" + GetCustomTypeName(t.code) + "]"; - } else { + if (t.code < kCustomBegin) { repr += TypeCode2Str(t.code); + } else { + repr += "custom[" + GetCustomTypeName(t.code) + "]"; } if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); diff --git a/nnvm/tests/python/compiler/test_to_relay.py b/nnvm/tests/python/compiler/test_to_relay.py index e79831d06cf26..dac14a8c1f220 100644 --- a/nnvm/tests/python/compiler/test_to_relay.py +++ b/nnvm/tests/python/compiler/test_to_relay.py @@ -18,7 +18,7 @@ from nnvm import testing from nnvm import to_relay import tvm -from tvm.relay import ir_pass +from tvm.relay import transform from tvm.relay import create_executor from tvm.contrib import graph_runtime import numpy as np @@ -41,10 +41,11 @@ def check_model(sym, shapes, dtypes, params): nnvm_rts.run(**inputs) nnvm_out = nnvm_rts.get_output(0) relay_model, params = to_relay.to_relay(net, shapes, dtypes, params) - relay_model = ir_pass.infer_type(relay_model) - relay_rts = create_executor(kind='graph', ctx=tvm.cpu(0), target='llvm') + mod = tvm.relay.Module.from_expr(relay_model) + mod = transform.InferType()(mod) + relay_rts = create_executor(kind='graph', mod=mod, ctx=tvm.cpu(0), target='llvm') inputs.update(params) - relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values())) + relay_out = relay_rts.evaluate()(*list(inputs.values())) np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy()) # def test_mlp(): diff --git a/python/tvm/api.py b/python/tvm/api.py index d88f06170543c..e4777b6e39649 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -531,7 +531,8 @@ def decl_buffer(shape, elem_offset=None, scope="", data_alignment=-1, - offset_factor=0): + offset_factor=0, + buffer_type=""): """Declare a new symbolic buffer. Normally buffer is created automatically during lower and build. @@ -574,11 +575,39 @@ def decl_buffer(shape, If 0 is pssed, the alignment will be set to 1. if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None. + buffer_type: str, optional, {"", "auto_broadcast"} + auto_broadcast buffer allows one to implement broadcast computation + without considering whether dimension size equals to one. + TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. + Returns ------- buffer : Buffer The created buffer + Example + ------- + Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation, + + .. 
code-block:: python + + m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") + n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") + o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2") + A = tvm.placeholder((m0, m1, m2), name='A') + B = tvm.placeholder((n0, n1, n2), name='B') + C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') + Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast") + Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast") + s = tvm.create_schedule(C.op) + fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) + ctx = tvm.cpu(0) + a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx) + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + Note ---- Buffer data structure reflects the DLTensor structure in dlpack. @@ -602,7 +631,7 @@ def decl_buffer(shape, data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, - data_alignment, offset_factor) + data_alignment, offset_factor, buffer_type) def layout(layout_str): """Create a layout node from a string. diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index 252882d17eceb..cffd42347b35d 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -142,7 +142,7 @@ def __init__(self, graph, input_shapes, records, target_ops, # Generate workload and schedule dictionaries. if isinstance(graph, relay.Module): - graph = graph[graph.entry_func] + graph = graph["main"] if isinstance(graph, relay.expr.Function): node_dict = {} diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index c0debaedede0d..5d07bd3fbce5f 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -21,6 +21,7 @@ import topi from tvm import relay, autotvm +from tvm.relay import transform from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple from tvm.relay.ty import TupleType, TensorType from tvm.autotvm.task import TaskExtractEnv @@ -80,6 +81,14 @@ def expr2graph(expr, target_ops, node_dict, node_list): task_pos += 1 +def _infer_type(node): + """A method to infer the type of a relay expression.""" + mod = relay.Module.from_expr(node) + mod = transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(node, relay.Function) else entry.body + + def _expr2graph_impl(expr, target_ops, node_dict, node_list): """Implementation to convert relay expr to graph data structure """ @@ -99,7 +108,7 @@ def _traverse_expr(node): node_entry["inputs"] += node_list[in_node_idx]["inputs"] else: node_entry["inputs"].append([in_node_idx, 0, 0]) - infer_out = relay.ir_pass.infer_type(node) + infer_out = _infer_type(node) out_type = infer_out._checked_type_ if isinstance(out_type, TensorType): node_entry["types"].append(out_type) @@ -127,10 +136,10 @@ def _traverse_expr(node): free_var = relay.Var("var_%d" % i, input_type) params.append(free_var) call = relay.Call(node.op, params, node.attrs) - func = relay.Function(params, call) + mod = relay.Module.from_expr(relay.Function(params, call))
relay.backend.compile_engine.get().clear() build_thread = threading.Thread(target=relay.build, - args=(func, + args=(mod, "llvm -device=tracing", None, None)) @@ -168,7 +177,7 @@ def _traverse_expr(node): node_dict[node] = node_index node_list.append(node_entry) - relay.ir_pass.post_order_visit(expr, _traverse_expr) + relay.analysis.post_order_visit(expr, _traverse_expr) def get_direct_ancestor(node_list, visited_dict, target_ops, node_idx, input_names): diff --git a/python/tvm/autotvm/graph_tuner/utils/utils.py b/python/tvm/autotvm/graph_tuner/utils/utils.py index 6151734299af6..b9777ef844595 100644 --- a/python/tvm/autotvm/graph_tuner/utils/utils.py +++ b/python/tvm/autotvm/graph_tuner/utils/utils.py @@ -17,6 +17,7 @@ # pylint: disable=eval-used,invalid-name,too-many-arguments """Utility functions""" from tvm import relay +from tvm.relay import transform def has_multiple_inputs(node_list, node_idx, input_names): @@ -107,4 +108,7 @@ def bind_inputs(expr, input_shapes=None, input_dtypes="float32"): rebind_dict[var] = updated_input_dict[var.name_hint] updated_expr = relay.expr.bind(expr, rebind_dict) - return relay.ir_pass.infer_type(updated_expr) + mod = relay.Module.from_expr(updated_expr) + mod = transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(updated_expr, relay.Function) else entry.body diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index d80443a208d66..5b0294ef2d07d 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -105,8 +105,9 @@ def extract_from_program(func, params, ops, target, target_host=None): relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems + mod = relay.Module.from_expr(func) build_thread = threading.Thread(target=_build, - args=(func, + args=(mod, target, target_host, params)) @@ -183,8 +184,9 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None): for func, param in zip(funcs, params): relay.backend.compile_engine.get().clear() # wrap build call in thread to avoid multiprocessing problems + mod = relay.Module.from_expr(func) build_thread = threading.Thread(target=my_build, - args=(func, + args=(mod, target, target_host, params)) diff --git a/python/tvm/make.py b/python/tvm/make.py index 7439952ad7adb..241edd6b09481 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -24,7 +24,6 @@ """ from __future__ import absolute_import as _abs from ._ffi.function import _init_api -from ._ffi.runtime_ctypes import TVMType def range_by_min_extent(min_value, extent): @@ -48,35 +47,6 @@ def range_by_min_extent(min_value, extent): return _range_by_min_extent(min_value, extent) -def static_cast(dtype, expr): - """Cast expr to dtype. - - If expr is scalar and dtype is a corresponding vector - type, a Broadcast is generated. Otherwise it is a Cast. - - Parameters - ---------- - dtype : str - The target data type. - - expr : Expr - The expression to be casted. - - Returns - ------- - casted : Expr - The casted expression. 
- """ - target_type = TVMType(dtype) - src_type = TVMType(expr.dtype) - if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits: - if src_type.lanes == target_type.lanes: - return expr - if src_type.lanes == 1 and target_type.lanes > 1: - return Broadcast(expr, target_type.lanes) - return Cast(dtype, expr) - - def node(type_key, **kwargs): """Make a new DSL node by its type key and fields diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 5536e503e6b67..dfac85bb1ed28 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -24,7 +24,7 @@ from . import expr_functor from . import module from . import adt -from . import ir_pass +from . import analysis from . import transform from .build_module import build, create_executor from .transform import build_config @@ -32,6 +32,7 @@ from . import parser from . import debug from . import param_dict +from . import feature # Root operators from .op import Op @@ -101,7 +102,7 @@ bind = expr.bind module_pass = transform.module_pass function_pass = transform.function_pass -alpha_equal = ir_pass.alpha_equal +alpha_equal = analysis.alpha_equal # ExprFunctor ExprFunctor = expr_functor.ExprFunctor @@ -122,3 +123,6 @@ ModulePass = transform.ModulePass FunctionPass = transform.FunctionPass Sequential = transform.Sequential + +# Feature +Feature = feature.Feature diff --git a/python/tvm/relay/_ir_pass.py b/python/tvm/relay/_analysis.py similarity index 89% rename from python/tvm/relay/_ir_pass.py rename to python/tvm/relay/_analysis.py index 3a0e0ac846b99..32a7324ae29f5 100644 --- a/python/tvm/relay/_ir_pass.py +++ b/python/tvm/relay/_analysis.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI exposing the Relay type inference and checking.""" +"""FFI exposing the passes for Relay program analysis.""" from tvm._ffi.function import _init_api -_init_api("relay._ir_pass", __name__) +_init_api("relay._analysis", __name__) diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi deleted file mode 100644 index 13035bb36f716..0000000000000 --- a/python/tvm/relay/_ir_pass.pyi +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -from . import ir -from .env import Module - -def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... -def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... -def _get_checked_type(expr: ir.Expr) -> ir.Type: ... -def well_formed(expr: ir.Expr) -> bool: ... -def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... 
diff --git a/python/tvm/relay/analysis.py b/python/tvm/relay/analysis.py new file mode 100644 index 0000000000000..ee8ce985fcbc0 --- /dev/null +++ b/python/tvm/relay/analysis.py @@ -0,0 +1,363 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains the set of passes for Relay, which exposes an interface for +configuring the passes and scripting them in Python. +""" +from . import _analysis +from . import _make +from .expr import Expr +from .ty import Type +from .module import Module +from .feature import Feature + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order, + applying fvisit to each node. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. + """ + return _analysis.post_order_visit(expr, fvisit) + + +def well_formed(expr): + """Check that each Var is only bound once (well formed). + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + well_form : bool + Whether the input expression is well formed + """ + return _analysis.well_formed(expr) + + +def check_kind(t, mod=None): + """Check that the type is well kinded and return the kind. + For example, this means a type cannot have a tensor of tensors, or be a tuple type + of two shapes. + + Parameters + ---------- + t : tvm.relay.Type + The type to check + + mod : Optional[tvm.relay.Module] + The global module. + + Returns + ------- + kind : Kind + the kind of t + + Examples + -------- + .. code:: python + + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape + assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type + """ + if mod is not None: + return _analysis.check_kind(t, mod) + else: + return _analysis.check_kind(t) + + +def free_vars(expr): + """Get free Vars from expression expr in Post DFS order. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of free variables in post DFS order. + + Note + ---- + The fact that Vars are post-DFS ordered is useful in + neural networks: usually this means the weights of previous + layers are ordered first. + """ + return _analysis.free_vars(expr) + + +def bound_vars(expr): + """Get bound vars from expression expr in post-DFS order. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of bound variables in post-DFS order. + """ + return _analysis.bound_vars(expr) + + +def all_vars(expr): + """Get all vars from expression expr in post-DFS order.
+ + Parameters + ---------- + expr : tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of all variables in post-DFS order. + """ + return _analysis.all_vars(expr) + + +def free_type_vars(expr, mod=None): + """Get free type variables from expression/type e + + Parameters + ---------- + expr : Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + mod : Optional[tvm.relay.Module] + The global module + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of free type variables in post-DFS order + """ + use_mod = mod if mod is not None else Module() + return _analysis.free_type_vars(expr, use_mod) + + +def bound_type_vars(expr, mod=None): + """Get bound type variables from expression/type e + + Parameters + ---------- + expr : Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + mod : Optional[tvm.relay.Module] + The global module + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of bound type variables in post-DFS order + """ + use_mod = mod if mod is not None else Module() + return _analysis.bound_type_vars(expr, use_mod) + + +def all_type_vars(expr, mod=None): + """Get all type variables from expression/type e + + Parameters + ---------- + expr : Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + mod : Optional[tvm.relay.Module] + The global module + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of all type variables in post-DFS order + """ + use_mod = mod if mod is not None else Module() + return _analysis.all_type_vars(expr, use_mod) + + +def alpha_equal(lhs, rhs): + """Compare two Relay expr for structural equivalence (alpha equivalence). + + Parameters + ---------- + lhs : tvm.relay.Expr + One of the input Expression. + + rhs : tvm.relay.Expr + One of the input Expression. + + Returns + ------- + result : bool + True iff lhs is alpha equal to rhs. + """ + return bool(_make._alpha_equal(lhs, rhs)) + + +def graph_equal(lhs, rhs): + """Compare two Relay expr for data-flow equivalence. + The difference between this and alpha-equality is that + variables are not expected to match between lhs and rhs; + they are treated as sources and are mapped between each other. + + Parameters + ---------- + lhs : tvm.relay.Expr + One of the input Expression. + + rhs : tvm.relay.Expr + One of the input Expression. + + Returns + ------- + result : bool + True iff lhs is data-flow equivalent to rhs. + """ + return bool(_make._graph_equal(lhs, rhs)) + + +def collect_device_info(expr): + """Collect the device allocation map for the given expression. The device + ids are propagated from the `device_copy` operators. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + ret : Dict[tvm.relay.expr, int] + A dictionary mapping tvm.relay.Expr to device type. + """ + return _analysis.CollectDeviceInfo(expr) + + +def collect_device_annotation_ops(expr): + """Collect the device annotation ops for the given expression. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + ret : Dict[tvm.relay.expr, int] + A dictionary mapping tvm.relay.Expr to device type where the keys are + annotation expressions. + """ + return _analysis.CollectDeviceAnnotationOps(expr) + + +def get_total_mac_number(expr): + """ + Count the number of MACs (multiply-accumulate) of a model + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. 
+ + Returns + ------- + result : int64 + The number of MACs (multiply-accumulate) of a model + """ + return _analysis.GetTotalMacNumber(expr) + + +def unmatched_cases(match, mod=None): + """ + Finds cases that the match expression does not catch, if any. + + Parameters + ---------- + match : tvm.relay.Match + The match expression + + mod : Optional[tvm.relay.Module] + The module (defaults to an empty module) + + Returns + ------- + missing_patterns : [tvm.relay.Pattern] + Patterns that the match expression does not catch. + """ + return _analysis.unmatched_cases(match, mod) + + +def detect_feature(a, b=None): + """ + Detect the feature used in a relay program. + + Parameters + ---------- + a : Union[tvm.relay.Expr, tvm.relay.Module] + The input expression or module. + + b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]] + The input expression or module. + The two arguments cannot both be expression or module. + + Returns + ------- + features : Set[Feature] + Features used in the program. + """ + if isinstance(a, Module): + a, b = b, a + return set([Feature(int(x)) for x in _analysis.detect_feature(a, b)]) + + +def structural_hash(value): + """Hash a Relay expression structurally. + + Parameters + ---------- + expr : Union[tvm.relay.Expr, tvm.relay.Type] + The expression to hash. + + Returns + ------- + result : int + The hash value + """ + if isinstance(value, Expr): + return int(_analysis._expr_hash(value)) + elif isinstance(value, Type): + return int(_analysis._type_hash(value)) + else: + msg = ("found value of type {0} expected" + + "relay.Expr or relay.Type").format(type(value)) + raise TypeError(msg) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index c54a65b78fb23..462dda9488c21 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -21,7 +21,7 @@ import numpy as np from . import _backend -from .. import _make, ir_pass, transform +from .. import _make, analysis, transform from .. import module from ... import register_func, nd from ..base import NodeBase, register_relay_node @@ -114,17 +114,18 @@ def __init__(self, value): _make.RefValue, value) -def _arg_to_ast(arg): +def _arg_to_ast(mod, arg): if isinstance(arg, TensorValue): return Constant(arg.data.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): - return Tuple([_arg_to_ast(field) for field in arg.fields]) + return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) elif isinstance(arg, tuple): - return Tuple([_arg_to_ast(field) for field in arg]) + return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, RefValue): - return RefCreate(_arg_to_ast(arg.value)) + return RefCreate(_arg_to_ast(mod, arg.value)) elif isinstance(arg, ConstructorValue): - return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields]) + return Call(mod.get_constructor(arg.tag), + [_arg_to_ast(mod, field) for field in arg.fields]) elif isinstance(arg, np.ndarray): return Constant(nd.array(arg)) elif isinstance(arg, Constant): @@ -163,6 +164,8 @@ def _convert_args(self, expr, args, kwargs): args: List[tvm.NDArray] The new arguments with all keyword arguments placed in the correct slot. 
""" + assert expr is not None + if not kwargs: return args @@ -229,7 +232,7 @@ def evaluate(self, expr=None, binds=None): if binds: scope_builder = ScopeBuilder() for key, value in binds.items(): - scope_builder.let(key, _arg_to_ast(value)) + scope_builder.let(key, _arg_to_ast(self.mod, value)) scope_builder.ret(expr) expr = scope_builder.get() @@ -237,7 +240,7 @@ def evaluate(self, expr=None, binds=None): return self._make_executor() if isinstance(expr, Function): - assert not ir_pass.free_vars(expr) + assert not analysis.free_vars(expr) if isinstance(expr, (Function, GlobalVar)): return self._make_executor(expr) @@ -286,29 +289,29 @@ def _make_executor(self, expr=None): assert self.mod is not None def _interp_wrapper(*args, **kwargs): if expr is None: - args = self._convert_args(self.mod[self.mod.entry_func], args, kwargs) + args = self._convert_args(self.mod["main"], args, kwargs) else: args = self._convert_args(expr, args, kwargs) relay_args = [] for arg in args: - relay_args.append(_arg_to_ast(arg)) + relay_args.append(_arg_to_ast(self.mod, arg)) # Set the entry function for the module. if expr is None: pass elif isinstance(expr, GlobalVar): - self.mod[self.mod.entry_func] = self.mod[expr] + self.mod["main"] = self.mod[expr] else: assert isinstance(expr, Function) func = Function([], Call(expr, relay_args)) relay_args = [] if self.mod: - self.mod[self.mod.entry_func] = func + self.mod["main"] = func else: self.mod = module.Module.from_expr(func) mod = self.optimize() - opt_expr = Call(mod[self.mod.entry_func.name_hint], relay_args) + opt_expr = Call(mod["main"], relay_args) return self._intrp(opt_expr) return _interp_wrapper diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index ceb403fe77174..152ee576e7bdb 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -45,7 +45,7 @@ def optimize(mod): ret : tvm.relay.Module The optimized module. """ - main_func = mod[mod.entry_func] + main_func = mod["main"] opt_passes = [] if not main_func.params and isinstance(main_func.body, GlobalVar): @@ -134,8 +134,8 @@ def _make_executor(self, expr=None): expr = expr if expr else self.mod assert expr, "either expr or self.mod should be not null." if isinstance(expr, Expr): - self.mod[self.mod.entry_func] = expr - main = self.mod[self.mod.entry_func] + self.mod["main"] = expr + main = self.mod["main"] def _vm_wrapper(*args, **kwargs): args = self._convert_args(main, args, kwargs) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index cdda17aa517b6..404829f74cf78 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -25,7 +25,6 @@ from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt from . import _build_module -from . import ir_pass from . import ty as _ty from . import expr as _expr from .module import Module as _Module @@ -178,7 +177,7 @@ def build(mod, target=None, target_host=None, params=None): The parameters of the final graph. """ if isinstance(mod, _Module): - func = mod[mod.entry_func] + func = mod["main"] elif isinstance(mod, _expr.Function): func = mod warnings.warn( @@ -227,23 +226,23 @@ class GraphExecutor(_interpreter.Executor): """ def __init__(self, mod, ctx, target): + assert mod is not None self.mod = mod self.ctx = ctx self.target = target def _make_executor(self, expr=None): - if not expr: - assert self.mod, "either expr or self.mod should be not null." 
- expr = self.mod[self.mod.entry_func] - ret_type = ir_pass.infer_type(expr).ret_type + if expr: + self.mod["main"] = expr + ret_type = self.mod["main"].checked_type.ret_type num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1 - graph_json, mod, params = build(expr, target=self.target) + graph_json, mod, params = build(self.mod, target=self.target) gmodule = _graph_rt.create(graph_json, mod, self.ctx) if params: gmodule.set_input(**params) def _graph_wrapper(*args, **kwargs): - args = self._convert_args(expr, args, kwargs) + args = self._convert_args(self.mod["main"], args, kwargs) # Create map of inputs. for i, arg in enumerate(args): gmodule.set_input(i, arg) @@ -280,6 +279,8 @@ def create_executor(kind="debug", target : :py:class:`tvm.Target` The corresponding context """ + if mod is None: + mod = _Module() if ctx is not None: assert ctx.device_type == _nd.context(str(target), 0).device_type else: diff --git a/python/tvm/relay/expr.pyi b/python/tvm/relay/expr.pyi index b7395c365390a..d264e99e05770 100644 --- a/python/tvm/relay/expr.pyi +++ b/python/tvm/relay/expr.pyi @@ -19,7 +19,7 @@ from typing import List import tvm from .base import Span, NodeBase from .ty import Type, TypeParam -from ._ir_pass import _get_checked_type +from ._analysis import _get_checked_type class Expr(NodeBase): @@ -128,4 +128,4 @@ class If(Expr): def __init__(self, cond, true_value, false_value): # type: (Expr, Expr, Expr) -> None - ... \ No newline at end of file + ... diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py index 18489b380ee71..43d9d21c09b5d 100644 --- a/python/tvm/relay/frontend/caffe2.py +++ b/python/tvm/relay/frontend/caffe2.py @@ -18,7 +18,7 @@ """Caffe2 frontend""" from __future__ import absolute_import as _abs import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -450,8 +450,8 @@ def from_caffe2(self, init_net, predict_net): else: outputs = out[0] - func = _expr.Function(ir_pass.free_vars(outputs), outputs) - self._mod[self._mod.entry_func] = func + func = _expr.Function(analysis.free_vars(outputs), outputs) + self._mod["main"] = func return self._mod, self._params diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index efd198803c2b6..c5057f35fedef 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -19,8 +19,8 @@ import logging from topi.util import get_const_tuple from .. import expr as _expr -from .. import expr as _expr -from .. import ir_pass +from .. import module as _module +from .. import transform as _transform from .. import op as _op @@ -407,9 +407,17 @@ def get_name(node): name = node.name_hint return name + +def infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = _transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(node, _expr.Function) else entry.body + def infer_shape(inputs): """A method to get the output shape of an intermediate node in the graph.""" - out_type = ir_pass.infer_type(inputs) + out_type = infer_type(inputs) out_shapes = get_const_tuple(out_type.checked_type.shape) return out_shapes @@ -417,7 +425,7 @@ def infer_channels(inputs, transpose=False): """A hack for getting 'channels' or 'units' since caffe2 does not provide these attributes. We check the shape of weights provided to get the number. 
""" - out_type = ir_pass.infer_type(inputs) + out_type = infer_type(inputs) out_shapes = [get_const_tuple(out_type.checked_type.shape)] channels = out_shapes[0][0] if not transpose else out_shapes[0][1] return channels diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 1cac547d07c95..e7b129e66724b 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -19,7 +19,7 @@ from __future__ import absolute_import as _abs import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -462,6 +462,6 @@ def from_coreml(model, shape=None): for o in spec.description.output] # for now return first output outexpr = outexpr[0] - func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) + func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 7b26ed5692df7..f452146ae46cc 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -23,7 +23,7 @@ from enum import Enum import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .common import get_relay_op, new_var @@ -820,7 +820,7 @@ def from_darknet(self): outputs = _as_list(sym) + self._outs outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - sym = _expr.Function(ir_pass.free_vars(outputs), outputs) + sym = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(sym), self._tvmparams def from_darknet(net, diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index ad033f9bf3260..91da87c84b809 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -20,7 +20,7 @@ import sys import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -743,6 +743,6 @@ def _convert_input_layer(keras_layer): outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ for oc in model._output_coordinates] outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) - func = _expr.Function(ir_pass.free_vars(outexpr), outexpr) + func = _expr.Function(analysis.free_vars(outexpr), outexpr) params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0bcee63ad3e8c..e40f1dea61a9f 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -20,7 +20,7 @@ import json import tvm -from .. import ir_pass +from .. import analysis, transform from .. import expr as _expr from .. import op as _op from .. 
import module as _module @@ -41,6 +41,13 @@ "relu" : _op.nn.relu } +def _infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(node, _expr.Function) else entry.body + def _mx_fully_connected(inputs, attrs): import mxnet as mx units = attrs.get_int("num_hidden") @@ -89,7 +96,8 @@ def _stable_softrelu(x): def _mx_compare(new_op, wrapper): def impl(inputs, attrs): - dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype + expr = _infer_type(inputs[0]) + dtype = expr.checked_type.dtype return wrapper(new_op)(inputs, attrs).astype(dtype) return impl @@ -258,7 +266,8 @@ def _mx_slice_like(inputs, attrs): def _mx_slice_axis(inputs, attrs): assert len(inputs) == 1 - shape = ir_pass.infer_type(inputs[0]).checked_type.shape + expr = _infer_type(inputs[0]) + shape = expr.checked_type.shape axis = attrs.get_int("axis") ax_beg = attrs.get_int("begin") ax_end = attrs.get_str("end") @@ -302,7 +311,8 @@ def _mx_crop_like(inputs, attrs): if offset == (0, 0): new_attrs["axes"] = (2, 3) return _op.slice_like(*inputs, **new_attrs) - like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape + expr = _infer_type(inputs[1]) + like_shape = expr.checked_type.shape new_attrs['begin'] = [0, 0, offset[0], offset[1]] new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], offset[1]+like_shape[3]] @@ -532,7 +542,8 @@ def _mx_resize(inputs, attrs): scale_width = attrs.get_float("scale_width", None) height = attrs.get_int("height", 1) width = attrs.get_int("width", 1) - shape = ir_pass.infer_type(inputs[0]).checked_type.shape + expr = _infer_type(inputs[0]) + shape = expr.checked_type.shape if scale_height is not None: height = (scale_height * shape[2]).astype("int32") if scale_width is not None: @@ -639,7 +650,8 @@ def _mx_broadcast_axis(inputs, attrs): assert len(axis) == len(size) if len(axis) == 0: return inputs[0] - src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape + expr = _infer_type(inputs[0]) + src_shape = expr.checked_type.shape tgt_shape = [] for i, dim in enumerate(src_shape): if i not in axis: @@ -734,7 +746,8 @@ def _rnn_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias, activati return out, [out] def _gru_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): - dtype = ir_pass.infer_type(data).checked_type.dtype + expr = _infer_type(data) + dtype = expr.checked_type.dtype i2h = _op.nn.bias_add(_op.nn.dense(data, i2h_weight), i2h_bias, axis=-1) h2h = _op.nn.bias_add(_op.nn.dense(states[0], h2h_weight), h2h_bias, axis=-1) i2h_r, i2h_z, i2h = _op.split(i2h, indices_or_sections=3, axis=1) @@ -776,7 +789,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): seq_data = inputs[0] concat_weight = inputs[1] init_states = inputs[2:] - data_shape = ir_pass.infer_type(seq_data).checked_type.shape + expr = _infer_type(seq_data) + data_shape = expr.checked_type.shape seq_len = int(data_shape[0]) assert len(concat_weight) == num_layers * 4 * direct @@ -1099,7 +1113,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, mod=None): outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return func @@ -1186,5 +1200,5 @@ def from_mxnet(symbol, else: msg = "mxnet.Symbol or 
gluon.HybridBlock expected, got {}".format(type(symbol)) raise ValueError(msg) - mod[mod.entry_func] = func + mod["main"] = func return mod, params diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index bb968ec0bea8a..397ca90de55f2 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -22,7 +22,7 @@ import numpy as np import tvm from ... import nd as _nd -from .. import ir_pass +from .. import analysis from .. import transform as _transform from .. import expr as _expr from .. import module as _module @@ -412,7 +412,7 @@ def _impl_v1(cls, inputs, attr, params): else: data, shape = inputs logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") - shape_params = ir_pass.free_vars(shape) + shape_params = analysis.free_vars(shape) func = _expr.Function(shape_params, shape) mod = _module.Module.from_expr(func) seq = _transform.Sequential([_transform.InferType(), @@ -1106,7 +1106,7 @@ def from_onnx(self, graph, opset): # now return the outputs outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(func), self._params def _parse_value_proto(self, value_proto): diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d754e85ef78d7..59e0983e95985 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -27,7 +27,8 @@ import tvm from topi.util import get_const_tuple -from .. import ir_pass +from .. import analysis +from .. import transform as _transform from .. import expr as _expr from .. import op as _op from ..expr_functor import ExprMutator @@ -38,9 +39,9 @@ def _infer_value(input_val, params): from tvm.contrib import graph_runtime # Check that all free variables have associated parameters. - assert all(var.name_hint in params.keys() for var in ir_pass.free_vars( + assert all(var.name_hint in params.keys() for var in analysis.free_vars( input_val)), "All inputs to infer must be available in params." 
- func = _expr.Function(ir_pass.free_vars(input_val), input_val) + func = _expr.Function(analysis.free_vars(input_val), input_val) with tvm.relay.build_config(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) ctx = tvm.context("llvm", 0) @@ -235,9 +236,16 @@ def _infer_out_shapes(inputs, params): """A method to get the output shape of intermediate nodes in the relay graph.""" return [_infer_shape(inputs, params)] +def _infer_type(node): + """A method to infer the type of an intermediate node in the relay graph.""" + mod = _module.Module.from_expr(node) + mod = _transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(node, _expr.Function) else entry.body + def _infer_shape(node, params=None): """A method to get the output shape of an intermediate node in the relay graph.""" - out_type = ir_pass.infer_type(node) + out_type = _infer_type(node) return get_const_tuple(out_type.checked_type.shape) def _get_param(params, input_node): @@ -1841,7 +1849,8 @@ def _while_loop(self): bind_map = {} for i, var in enumerate(self.loop_vars): if not isinstance(var, _expr.Var): - var_type = ir_pass.infer_type(var).checked_type + var_chk = _infer_type(var) + var_type = var_chk.checked_type else: var_type = var.type_annotation @@ -2112,8 +2121,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _expr.Function(ir_pass.free_vars(out), out) - self._mod[self._mod.entry_func] = func + func = _expr.Function(analysis.free_vars(out), out) + self._mod["main"] = func return self._mod, self._params def _parse_import_prerequisites(self, graph): @@ -2329,7 +2338,8 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ else: if node_name_prefix not in self._branches: self._branches[node_name_prefix] = Branch() - self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0]) + chk_op = _infer_type(op[0]) + self._branches[node_name_prefix].cond = chk_op elif node.op == "NextIteration": op = self._nodes[node.input[0]] assert len(op) == 1 diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index fe163871fa60f..bf1938b1481e7 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -20,7 +20,7 @@ import math import numpy as np import tvm -from .. import ir_pass +from .. import analysis from .. import expr as _expr from .. import module as _module from .. import op as _op @@ -914,5 +914,5 @@ def from_tflite(model, shape_dict, dtype_dict): params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()} outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - func = _expr.Function(ir_pass.free_vars(outputs), outputs) + func = _expr.Function(analysis.free_vars(outputs), outputs) return _module.Module.from_expr(func), params diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py deleted file mode 100644 index 1748571cb3163..0000000000000 --- a/python/tvm/relay/ir_pass.py +++ /dev/null @@ -1,704 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=no-else-return -# pylint: disable=unidiomatic-typecheck -""" -This file contains the set of passes for Relay, which exposes an interface for -configuring the passes and scripting them in Python. -""" -from . import _ir_pass -from . import _make -from .expr import Expr -from .ty import Type -from .module import Module -from .feature import Feature - - -def post_order_visit(expr, fvisit): - """Recursively visit the ir in post DFS order node, - apply fvisit. Each node is guaranteed to be visited - only once. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - fvisit : function - The visitor function to be applied. - """ - return _ir_pass.post_order_visit(expr, fvisit) - -def infer_type(expr, mod=None): - """Infer the type of expr under the context of mod. - - Parameters - ---------- - expr: tvm.relay.Expr - The input expression. - - mod: Optional[tvm.relay.Module] - The global module. - - Returns - ------- - checked_expr : tvm.relay.Expr - The checked expression. - """ - return _ir_pass.infer_type(expr, mod) - - -def backward_fold_scale_axis(expr): - """Backward fold axis scaling into weights of conv2d/dense. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - - Returns - ------- - folded_expr : tvm.relay.Expr - The folded expression after transformation. - - Note - ---- - It is recommended to call backward_fold_scale_axis - before using forward_fold_scale_axis. - As backward folding targets common conv-bn pattern. - """ - return _ir_pass.backward_fold_scale_axis(expr) - -def eta_expand(expr, mod): - """Add abstraction over a function. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - mod : tvm.relay.Module - The global module. - - Returns - ------- - expanded_expr : tvm.relay.Expr - The expression after eta expansion. - """ - return _ir_pass.eta_expand(expr, mod) - -def forward_fold_scale_axis(expr): - """Fold the scaling of axis into weights of conv2d/dense. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - - Returns - ------- - folded_expr : tvm.relay.Expr - The folded expression after transformation. - - Note - ---- - It is recommended to call backward_fold_scale_axis - before using forward_fold_scale_axis. - As backward folding targets common conv-bn pattern. - """ - return _ir_pass.forward_fold_scale_axis(expr) - - -def well_formed(expr): - """Check that each Var is only bound once (well formed). 
- - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - - Returns - ------- - well_form : bool - Whether the input expression is well formed - """ - return _ir_pass.well_formed(expr) - - -def check_kind(t, mod=None): - """Check that the type is well kinded and return the kind. - For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. - - Parameters - ---------- - t : tvm.relay.Type - The type to check - - mod : Optional[tvm.relay.Module] - The global module. - - Returns - ------- - kind : Kind - the kind of t - - Examples - -------- - .. code:: python - - assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape - assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type - """ - if mod is not None: - return _ir_pass.check_kind(t, mod) - else: - return _ir_pass.check_kind(t) - - -def free_vars(expr): - """Get free Vars from expression expr in Post DFS order. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - - Returns - ------- - free : List[tvm.relay.Var] - The list of free variables in post DFS order. - - Note - ---- - The fact that Vars are post-DFS ordred are useful in - neural networks: usually this means weights of previous - are ordered first. - """ - return _ir_pass.free_vars(expr) - - -def bound_vars(expr): - """Get bound vars from expression expr in post-DFS order. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - - Returns - ------- - free : List[tvm.relay.Var] - The list of bound variables in post-DFS order. - """ - return _ir_pass.bound_vars(expr) - - -def all_vars(expr): - """Get all vars from expression expr in post-DFS order. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - - Returns - ------- - free : List[tvm.relay.Var] - The list of all variables in post-DFS order. - """ - return _ir_pass.all_vars(expr) - - -def free_type_vars(expr, mod=None): - """Get free type variables from expression/type e - - Parameters - ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] - The input expression/type - - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - free : List[tvm.relay.TypeVar] - The list of free type variables in post-DFS order - """ - use_mod = mod if mod is not None else Module() - return _ir_pass.free_type_vars(expr, use_mod) - - -def bound_type_vars(expr, mod=None): - """Get bound type variables from expression/type e - - Parameters - ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] - The input expression/type - - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - free : List[tvm.relay.TypeVar] - The list of bound type variables in post-DFS order - """ - use_mod = mod if mod is not None else Module() - return _ir_pass.bound_type_vars(expr, use_mod) - - -def all_type_vars(expr, mod=None): - """Get all type variables from expression/type e - - Parameters - ---------- - expr : Union[tvm.relay.Expr,tvm.relay.Type] - The input expression/type - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - free : List[tvm.relay.TypeVar] - The list of all type variables in post-DFS order - """ - use_mod = mod if mod is not None else Module() - return _ir_pass.all_type_vars(expr, use_mod) - - -def simplify_inference(expr): - """ Simplify the data-flow graph for inference phase. 
- - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - Returns - ------- - result : tvm.relay.Expr - An expression which is semantically equal to the input expression, - but with some simplification - """ - return _ir_pass.simplify_inference(expr) - - -def canonicalize_ops(expr): - """ Canonicalize special operators to basic operators. - This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - Returns - ------- - result : tvm.relay.Expr - An expression without bias_add - """ - return _ir_pass.canonicalize_ops(expr) - - -def dead_code_elimination(expr, inline_once=False): - """ Remove expressions which does not effect the program result (dead code). - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - inline_once : Optional[Bool] - Whether to inline binding that occur only once. - Returns - ------- - result : tvm.relay.Expr - An expression which is semantically equal to the input expression, - but with dead code removed. - """ - return _ir_pass.dead_code_elimination(expr, inline_once) - - -def alpha_equal(lhs, rhs): - """Compare two Relay expr for structural equivalence (alpha equivalence). - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - - Returns - ------- - result : bool - True iff lhs is alpha equal to rhs. - """ - return bool(_make._alpha_equal(lhs, rhs)) - - -def graph_equal(lhs, rhs): - """Compare two Relay expr for data-flow equivalence. - The difference between this and alpha-equality is that - variables are not expected to match between lhs and rhs; - they are treated as sources and are mapped between each other. - - Parameters - ---------- - lhs : tvm.relay.Expr - One of the input Expression. - - rhs : tvm.relay.Expr - One of the input Expression. - - Returns - ------- - result : bool - True iff lhs is data-flow equivalent to rhs. - """ - return bool(_make._graph_equal(lhs, rhs)) - - -def structural_hash(value): - """Hash a Relay expression structurally. - - Parameters - ---------- - expr : Union[tvm.relay.Expr, tvm.relay.Type] - The expression to hash. - - Returns - ------- - result : int - The hash value - """ - if isinstance(value, Expr): - return int(_ir_pass._expr_hash(value)) - elif isinstance(value, Type): - return int(_ir_pass._type_hash(value)) - else: - msg = ("found value of type {0} expected" + - "relay.Expr or relay.Type").format(type(value)) - raise TypeError(msg) - - -def fold_constant(expr): - """Fold the constant expression in expr. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - transformed_expr : tvm.relay.Expr - The transformed expression. - """ - return _ir_pass.FoldConstant(expr) - - -def fuse_ops(expr, opt_level=1, mod=None): - """Fuse operators in expr together. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - opt_level : int - The level of fuse optimization. - - mod : tvm.relay.Module - The module to perform fusion over. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression, containing fused result. - """ - return _ir_pass.FuseOps(expr, opt_level, mod) - - -def combine_parallel_conv2d(expr, min_num_branches=3): - """Combine multiple conv2d into one. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. 
- - min_num_branches : int - The minimum number of parallel branches when the transformation should be applied. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression - """ - return _ir_pass.CombineParallelConv2D(expr, min_num_branches) - - -def alter_op_layout(expr): - """Alternate the layouts of operators or replace primitive operators with - other expressions. - This pass can be used for computing convolution in custom layouts or - other general weight pre-transformation. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression with alternated layout. - """ - return _ir_pass.AlterOpLayout(expr) - - -def rewrite_annotated_ops(expr, fallback_device): - """Rewrite the annotated program where annotation operators, e.g. - `on_deivce`, mark which device an expression should be scheduled to. - This pass helps heterogeneous execution where different operators may need - to be allocated on various devices. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - fallback_device : int - The fallback device type. It is also used as the default device for - operators with no annotated device. - - Returns - ------- - transformed_expr : tvm.relay.Expr - Transformed expression with cross device data copy operators. - """ - return _ir_pass.RewriteDeviceAnnotation(expr, fallback_device) - - -def collect_device_info(expr): - """Collect the device allocation map for the given expression. The device - ids are propagated from the `device_copy` operators. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.expr, int] - A dictionary mapping tvm.relay.Expr to device type. - """ - return _ir_pass.CollectDeviceInfo(expr) - - -def collect_device_annotation_ops(expr): - """Collect the device annotation ops for the given expression. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - ret : Dict[tvm.relay.expr, int] - A dictionary mapping tvm.relay.Expr to device type where the keys are - annotation expressions. - """ - return _ir_pass.CollectDeviceAnnotationOps(expr) - - -def to_a_normal_form(expr, mod=None): - """ - Turn Graph Normal Form expression into A Normal Form Expression. - - The scope of the root expression is the global scope. - - The scope of any non root expression is the least common ancestor of all it's scope. - - Values are ordered by post-DFS order in each scope. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module. - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.to_a_normal_form(expr, mod) - - -def to_graph_normal_form(expr): - """Turn A Normal Form expression into Graph Normal Form expression - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - Returns - ------- - result : tvm.relay.Expr - The output expression - """ - return _ir_pass.to_graph_normal_form(expr) - - -def gradient(expr, mod=None, mode='higher_order'): - """ - Transform the input function, - returning a function that calculate the original result, - paired with gradient of the input. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, which is a Function or a GlobalVar. - - mod : Optional[tvm.relay.Module] - - mode : Optional[String] - The mode of the automatic differentiation algorithm. 
- 'first_order' only work on first order code, but will not produce reference nor closure. - 'higher_order' work on all code using reference and closure. - - Returns - ------- - expr : tvm.relay.Expr - The transformed expression. - """ - if mode == 'first_order': - return _ir_pass.first_order_gradient(expr, mod) - elif mode == 'higher_order': - return _ir_pass.gradient(expr, mod) - else: - raise Exception('unknown mode') - - -def get_total_mac_number(expr): - """ - Count the number of MACs (multiply-accumulate) of a model - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - Returns - ------- - result : int64 - The number of MACs (multiply-accumulate) of a model - """ - return _ir_pass.GetTotalMacNumber(expr) - - -def eliminate_common_subexpr(expr, fskip=None): - """ - Eliminate common subexpressions. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - fskip : function - The callback function that decides whether an expression should be skipped. - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.eliminate_common_subexpr(expr, fskip) - - -def partial_evaluate(expr, mod=None): - """ - Evaluate the static fragment of the code. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.partial_evaluate(expr, mod) - - -def unmatched_cases(match, mod=None): - """ - Finds cases that the match expression does not catch, if any. - - Parameters - ---------- - match : tvm.relay.Match - The match expression - mod : Optional[tvm.relay.Module] - The module (defaults to an empty module) - - Returns - ------- - missing_patterns : [tvm.relay.Pattern] - Patterns that the match expression does not catch. - """ - return _ir_pass.unmatched_cases(match, mod) - - -def detect_feature(a, b=None): - """ - Detect the feature used in a relay program. - - Parameters - ---------- - a : Union[tvm.relay.Expr, tvm.relay.Module] - The input expression or module. - - b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]] - The input expression or module. - The two arguments cannot both be expression or module. - - Returns - ------- - features : Set[Feature] - Features used in the program. - """ - if isinstance(a, Module): - a, b = b, a - return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)]) diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 138dfa8822154..8ac15f743fc4f 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -33,7 +33,7 @@ class Module(RelayNode): Parameters ---------- - functions : dict, optional. + functions: Optional[dict]. Map of global var to Function """ def __init__(self, functions=None, type_definitions=None): @@ -78,17 +78,11 @@ def __setitem__(self, var, val): def _add(self, var, val, update=False): if isinstance(val, _expr.Expr): if isinstance(var, _base.string_types): - var = _expr.GlobalVar(var) - - # TODO(@jroesch): Port this logic to C++. 
- if not isinstance(val, _expr.Function): - if isinstance(val, _expr.GlobalVar): - val = ir_pass.eta_expand(val, self) + if _module.Module_ContainGlobalVar(self, var): + var = _module.Module_GetGlobalVar(self, var) else: - val = _expr.Function([], val) - - - _make.Module_Add(self, var, val, update) + var = _expr.GlobalVar(var) + _module.Module_Add(self, var, val, update) else: assert isinstance(val, _ty.Type) if isinstance(var, _base.string_types): @@ -100,7 +94,7 @@ def __getitem__(self, var): Parameters ---------- - var: str or GlobalVar + var: Union[String, GlobalVar, GlobalTypeVar] The name or global variable. Returns @@ -165,6 +159,25 @@ def get_global_type_var(self, name): """ return _module.Module_GetGlobalTypeVar(self, name) + def get_constructor(self, tag): + """Look up an ADT constructor by tag. + + Parameters + ---------- + tag: int + The tag for a constructor. + + Returns + ------- + constructor: Constructor + The constructor associated with the given tag, + + Raises + ------ + tvm.TVMError if the corresponding constructor cannot be found. + """ + return _module.Module_LookupTag(self, tag) + @staticmethod def from_expr(expr): return _module.Module_FromExpr(expr) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 7bce9dd3c5b99..1de86173040d0 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -137,6 +137,12 @@ def conv2d_transpose(data, dilation : Tuple[int], optional Specifies the dilation rate to be used for dilated convolution. + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + groups : int, optional Number of groups for grouped convolution. diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index fa70e19544677..beebceaf8590c 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,7 @@ from . import _quantize from .. import expr as _expr from .. import module as _module -from .. import ir_pass as _ir_pass +from .. import analysis as _analysis from .. import transform as _transform from .. import op as _op from ... import make as _make @@ -250,7 +250,7 @@ def _make_const(val): const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) - _ir_pass.post_order_visit(graph, visit_func) + _analysis.post_order_visit(graph, visit_func) return _expr.bind(graph, const_params) @@ -365,4 +365,4 @@ def quantize(graph, params=None, dataset=None): mod = optimize(mod) mod = quantize_seq(mod) - return mod[mod.entry_func.name_hint] + return mod["main"] diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 7a5007bbfb8f2..de9e55b369d19 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -17,6 +17,9 @@ """Utilities for testing and benchmarks""" from __future__ import absolute_import as _abs +import tvm.relay as relay +from tvm.relay import transform + from . import mlp from . import resnet from . 
import dqn @@ -32,3 +35,15 @@ from .config import ctx_list from .init import create_workload from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def run_infer_type(expr): + return run_opt_pass(expr, transform.InferType()) diff --git a/python/tvm/relay/testing/dcgan.py b/python/tvm/relay/testing/dcgan.py index 4ee0bd13a5a7e..c6b258badb5b6 100644 --- a/python/tvm/relay/testing/dcgan.py +++ b/python/tvm/relay/testing/dcgan.py @@ -81,7 +81,7 @@ def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None, dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv") tanh = relay.tanh(dc64) - args = relay.ir_pass.free_vars(tanh) + args = relay.analysis.free_vars(tanh) return relay.Function(args, tanh) @@ -103,8 +103,8 @@ def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype= Returns ------- - net : nnvm.symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains a DCGAN network. params : dict of str to NDArray The parameters. """ diff --git a/python/tvm/relay/testing/densenet.py b/python/tvm/relay/testing/densenet.py index de3ebe36eb7bd..f9b479153bfad 100644 --- a/python/tvm/relay/testing/densenet.py +++ b/python/tvm/relay/testing/densenet.py @@ -79,7 +79,7 @@ def _make_dense_net(num_init_features, growth_rate, block_config, ret = layers.dense_add_bias(flat, units=classes, name='dense') - return relay.Function(relay.ir_pass.free_vars(ret), ret) + return relay.Function(relay.analysis.free_vars(ret), ret) def get_workload(densenet_size=121, classes=1000, batch_size=4, image_shape=(3, 224, 224), dtype='float32'): @@ -105,8 +105,8 @@ def get_workload(densenet_size=121, classes=1000, batch_size=4, Returns ------- - net: relay.Function - The computation graph representing densenet. + mod: tvm.relay.Module + The relay module that contains a DenseNet network. params : dict of str to NDArray The benchmark paraeters. diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 034ac0a6c2e5f..cdf9d24af996a 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -54,7 +54,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" relu4 = relay.nn.relu(dense1) dense2 = layers.dense_add_bias(relu4, units=num_actions, name="dense2") - args = relay.ir_pass.free_vars(dense2) + args = relay.analysis.free_vars(dense2) return relay.Function(args, dense2) @@ -72,8 +72,8 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - net : nnvm.symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains a DQN network. params : dict of str to NDArray The parameters. 
""" diff --git a/python/tvm/relay/testing/inception_v3.py b/python/tvm/relay/testing/inception_v3.py index c9ec3293ed0a1..4da543257c318 100644 --- a/python/tvm/relay/testing/inception_v3.py +++ b/python/tvm/relay/testing/inception_v3.py @@ -266,7 +266,7 @@ def get_net(batch_size, fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes) fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1) inception_v3 = relay.nn.softmax(data=fc1) - args = relay.ir_pass.free_vars(inception_v3) + args = relay.analysis.free_vars(inception_v3) return relay.Function(args, inception_v3) def get_workload(batch_size=1, num_classes=1000, @@ -289,8 +289,8 @@ def get_workload(batch_size=1, num_classes=1000, Returns ------- - net : nnvm.Symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains an Inception V3 network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/init.py b/python/tvm/relay/testing/init.py index b246b46172766..0b8ab2b42029b 100644 --- a/python/tvm/relay/testing/init.py +++ b/python/tvm/relay/testing/init.py @@ -144,16 +144,16 @@ def create_workload(net, initializer=None, seed=0): Returns ------- - net : tvm.relay.Function - The updated dataflow + mod : tvm.relay.Module + The created relay module. params : dict of str to NDArray The parameters. """ - net = relay.ir_pass.infer_type(net) + mod = relay.Module.from_expr(net) + mod = relay.transform.InferType()(mod) shape_dict = { - v.name_hint : v.checked_type for v in net.params} - net.astext() + v.name_hint : v.checked_type for v in mod["main"].params} np.random.seed(seed) initializer = initializer if initializer else Xavier() params = {} @@ -163,4 +163,4 @@ def create_workload(net, initializer=None, seed=0): init_value = np.zeros(v.concrete_shape).astype(v.dtype) initializer(k, init_value) params[k] = tvm.nd.array(init_value, ctx=tvm.cpu(0)) - return net, params + return mod, params diff --git a/python/tvm/relay/testing/lstm.py b/python/tvm/relay/testing/lstm.py index b0915e033ccbf..d0134c1a864d6 100644 --- a/python/tvm/relay/testing/lstm.py +++ b/python/tvm/relay/testing/lstm.py @@ -154,7 +154,7 @@ def get_net(iterations, num_hidden, batch_size=1, dtype="float32"): builder.ret(out) body = builder.get() - args = relay.ir_pass.free_vars(body) + args = relay.analysis.free_vars(body) return relay.Function(args, body, input_type) @@ -173,8 +173,8 @@ def get_workload(iterations, num_hidden, batch_size=1, dtype="float32"): The data type Returns ------- - net : nnvm.symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains a LSTM network. params : dict of str to NDArray The parameters. """ diff --git a/python/tvm/relay/testing/mlp.py b/python/tvm/relay/testing/mlp.py index 562ef21ba9f1c..337bde5d5889e 100644 --- a/python/tvm/relay/testing/mlp.py +++ b/python/tvm/relay/testing/mlp.py @@ -58,7 +58,7 @@ def get_net(batch_size, fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes) fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1) mlp = relay.nn.softmax(data=fc3) - args = relay.ir_pass.free_vars(mlp) + args = relay.analysis.free_vars(mlp) return relay.Function(args, mlp) @@ -84,8 +84,8 @@ def get_workload(batch_size, Returns ------- - net : relay.Function - The dataflow. + mod : tvm.relay.Module + The relay module that contains a mlp network. params : dict of str to NDArray The parameters. 
diff --git a/python/tvm/relay/testing/mobilenet.py b/python/tvm/relay/testing/mobilenet.py index 78e1d82456c84..3b068c05a24ed 100644 --- a/python/tvm/relay/testing/mobilenet.py +++ b/python/tvm/relay/testing/mobilenet.py @@ -108,7 +108,7 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224), weight = relay.var('fc_weight') fc = relay.nn.dense(data=flatten, weight=weight, units=num_classes) softmax = relay.nn.softmax(data=fc) - return relay.Function(relay.ir_pass.free_vars(softmax), softmax) + return relay.Function(relay.analysis.free_vars(softmax), softmax) def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtype='float32'): @@ -130,8 +130,8 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp Returns ------- - net : relay.Function - The computational graph + mod : tvm.relay.Module + The relay module that contains a MobileNet network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 9ba57ae09ef5b..a8e369b740219 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -169,7 +169,7 @@ def resnet(units, flat = relay.nn.batch_flatten(data=pool1) fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1') net = relay.nn.softmax(data=fc1) - return relay.Function(relay.ir_pass.free_vars(net), net) + return relay.Function(relay.analysis.free_vars(net), net) def get_net(batch_size, @@ -261,8 +261,8 @@ def get_workload(batch_size=1, Returns ------- - net : relay.Function - The computational graph + mod : tvm.relay.Module + The relay module that contains a ResNet network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/squeezenet.py b/python/tvm/relay/testing/squeezenet.py index c7b8e8db166b6..1e9ea73e9360e 100644 --- a/python/tvm/relay/testing/squeezenet.py +++ b/python/tvm/relay/testing/squeezenet.py @@ -119,7 +119,7 @@ def get_net(batch_size, image_shape, num_classes, version, dtype): net = relay.nn.global_avg_pool2d(net) net = relay.nn.batch_flatten(net) net = relay.nn.softmax(net) - args = relay.ir_pass.free_vars(net) + args = relay.analysis.free_vars(net) return relay.Function(args, net) @@ -149,8 +149,8 @@ def get_workload(batch_size=1, Returns ------- - net : nnvm.Symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains a SqueezeNet network. params : dict of str to NDArray The parameters. diff --git a/python/tvm/relay/testing/vgg.py b/python/tvm/relay/testing/vgg.py index bec141f70ffd0..205c5b1fa8e39 100644 --- a/python/tvm/relay/testing/vgg.py +++ b/python/tvm/relay/testing/vgg.py @@ -90,7 +90,7 @@ def get_net(batch_size, image_shape, num_classes, dtype, num_layers=11, batch_no feature = get_feature(data, layers, filters, batch_norm) classifier = get_classifier(feature, num_classes) symbol = relay.nn.softmax(data=classifier) - args = relay.ir_pass.free_vars(symbol) + args = relay.analysis.free_vars(symbol) return relay.Function(args, symbol) @@ -124,8 +124,8 @@ def get_workload(batch_size, Returns ------- - net : nnvm.Symbol - The computational graph + mod : tvm.relay.Module + The relay module that contains a VGG network. params : dict of str to NDArray The parameters. 
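With the changes above, every `relay.testing.*.get_workload` now returns a `(module, params)` pair rather than a bare function or symbol. A hedged usage sketch follows, assuming this revision of TVM; the chosen network, depth, and shapes are just examples.

```python
from tvm import relay
from tvm.relay import testing

# get_workload now returns a tvm.relay.Module plus randomly initialized params.
mod, params = testing.resnet.get_workload(
    num_layers=18, batch_size=1, image_shape=(3, 224, 224))

# The network itself is the module's "main" function.
main_fn = mod["main"]
print(len(main_fn.params), "free variables (data + weights)")

# mod and params are then handed to the Relay build pipeline in place of the
# old standalone Function.
```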
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 5f47e5b446aa7..2805e0b429fa0 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -277,6 +277,40 @@ def FoldScaleAxis(): return _transform.FoldScaleAxis() +def BackwardFoldScaleAxis(): + """Backward fold axis scaling into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to backward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis + before using forward_fold_scale_axis. + As backward folding targets common conv-bn pattern. + """ + return _transform.BackwardFoldScaleAxis() + + +def ForwardFoldScaleAxis(): + """Fold the scaling of axis into weights of conv2d/dense. + + Returns + ------- + ret : tvm.relay.Pass + The registered pass to forward fold expressions. + + Note + ---- + It is recommended to call backward_fold_scale_axis + before using forward_fold_scale_axis. + As backward folding targets common conv-bn pattern. + """ + return _transform.ForwardFoldScaleAxis() + + def SimplifyInference(): """Simplify the data-flow graph for inference phase. An simplified expression which is semantically equal to the input expression will be returned. @@ -302,15 +336,20 @@ def CanonicalizeOps(): return _transform.CanonicalizeOps() -def DeadCodeElimination(): - """ Remove expressions which does not effect the program result (dead code). +def DeadCodeElimination(inline_once=False): + """Remove expressions which does not effect the program result (dead code). + + Parameters + ---------- + inline_once: Optional[Bool] + Whether to inline binding that occurs only once. Returns ------- ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. """ - return _transform.DeadCodeElimination() + return _transform.DeadCodeElimination(inline_once) def FoldConstant(): @@ -401,11 +440,26 @@ def ToANormalForm(): Returns ------- - ret: tvm.relay.Pass + ret: Union[tvm.relay.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ return _transform.ToANormalForm() + +def ToCPS(expr, mod=None): + """ + Turn expression into continuation passing style(CPS). + + Every intermediate compute will be passed to a continuation. + + Returns + ------- + result: tvm.relay.Pass + The registered pass that transforms an expression into CPS. + """ + return _transform.to_cps(expr, mod) + + def EtaExpand(): """Add abstraction over a function @@ -416,6 +470,7 @@ def EtaExpand(): """ return _transform.EtaExpand() + def ToGraphNormalForm(): """Turn A Normal Form expression into Graph Normal Form expression @@ -447,13 +502,21 @@ def EliminateCommonSubexpr(fskip=None): def PartialEvaluate(): """Evaluate the static fragment of the code. + Note + ---- + This transformation could be either `Module -> Module` or `Expr -> Expr`. + It will directly transform the input expression to a new one if the target + expression is provided. Otherwise, it will rely on the pass manager to + carry out transformation. + Returns ------- - ret : tvm.relay.Pass + ret: tvm.relay.Pass The registered pass that performs partial evaluation on an expression. """ return _transform.PartialEvaluate() + def CanonicalizeCast(): """ Canonicalize cast expressions to make operator fusion more efficient. 
@@ -465,6 +528,80 @@ def CanonicalizeCast(): """ return _transform.CanonicalizeCast() + +def gradient(expr, mod=None, mode='higher_order'): + """ + Transform the input function, + returning a function that calculate the original result, + paired with gradient of the input. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression, which is a Function or a GlobalVar. + + mod : Optional[tvm.relay.Module] + + mode : Optional[String] + The mode of the automatic differentiation algorithm. + 'first_order' only works on first order code, but will not produce + reference nor closure. + 'higher_order' works on all code using reference and closure. + + Returns + ------- + expr : tvm.relay.Expr + The transformed expression. + """ + if mode == 'first_order': + return _transform.first_order_gradient(expr, mod) + if mode == 'higher_order': + return _transform.gradient(expr, mod) + raise Exception('unknown mode') + + +def to_cps(func, mod=None): + """ + Turn expression into CPS expression. + + Every intermediate compute will be passed to a continuation. + + Parameters + ---------- + func: tvm.relay.Function + The input function. + + mod: Optional[tvm.relay.Module] + The global module. + + Returns + ------- + result: tvm.relay.Function + The output function. + """ + return _transform.to_cps(func, mod) + + +def un_cps(func): + """ + Turn an cps function into a Function without the continuation argument. + + Note that this will not give the exact same interface as before cps: + If the input/output is higher order, they will still be in cps form. + + Parameters + ---------- + func: tvm.relay.Function + The input function + + Returns + ------- + result: tvm.relay.Function + The output function + """ + return _transform.un_cps(func) + + def _wrap_class_module_pass(pass_cls, pass_info): """Wrap a python class as function pass""" class PyModulePass(ModulePass): diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 42d60b85e375f..00ac715e8c075 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -207,7 +207,13 @@ TVM_REGISTER_API("Range") }); TVM_REGISTER_API("_Buffer") -.set_body_typed(BufferNode::make); +.set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator std::string(); + BufferType type = (buffer_type == "auto_broadcast") ? 
kAutoBroadcast : kDefault; + *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8], type); + }); TVM_REGISTER_API("_BufferAccessPtr") .set_body_method(&Buffer::access_ptr); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 2198aee934787..626fc18c57df9 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -106,6 +106,7 @@ bool Analyzer::CanProve(const Expr& expr) { Expr Analyzer::Simplify(const Expr& expr) { if (is_const(expr)) return expr; auto res = this->rewrite_simplify(expr); + if (is_const(res)) return res; res = this->canonical_simplify(res); return res; } diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 395a371f43af7..003ba8def7612 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor { void Deduce(); void Visit(const NodeRef& e) final { - if (!success) return; + if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { - success = false; + success_ = false; return; } } @@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor { void Visit_(const Add* op) final { bool left = op->a.get() == path_[iter_]; - result -= left ? op->b : op->a; + result_ -= left ? op->b : op->a; Visit(left ? op->a : op->b); } void Visit_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { - result += op->b; + result_ += op->b; } else { - result -= op->a; - result = - result; - is_greater = !is_greater; + result_ -= op->a; + result_ = - result_; + is_greater_ = !is_greater_; } Visit(left ? op->a : op->b); } @@ -130,43 +130,65 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; + Expr target_var = left ? op->a : op->b; - SignType sign; + SignType sign_operand; if (operand.type().is_uint()) { - sign = kPositive; + sign_operand = kPositive; } else { - sign = expr_map_[operand].sign_type(); + sign_operand = expr_map_[operand].sign_type(); } - if (sign == SignType::kNegative) { - is_greater = !is_greater; - } else if (sign == SignType::kUnknown) { + if (sign_operand == SignType::kNegative) { + is_greater_ = !is_greater_; + } else if (sign_operand == SignType::kUnknown) { // unable to get the sign of operand - success = false; + success_ = false; return; } - // always use relax bound - bool divided = can_prove(result % operand == 0); - result = result / operand; - // since system will round down when not divided - // eg. 2/4 -> 0; -2/4 -> -1 - // no need fix for !is_greater: - // eg. a <= 2/4 -> a <= 0 - // eg. a <= 0/4 -> a <= 0 - // so just fix for not divided and is_greater - // eg. a >= 2/4 -> a >= 0 + 1 - // eg. a >= 0/4 -> a >= 0 - if (is_greater && !divided) { - result += 1; + bool divided = analyzer_.CanProve(result_ % operand == 0); + + result_ = result_ / operand; + + if (!divided) { + // Handle non-divisible case + // NOTE: this accounts for truc div behavior. + bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); + + if (is_greater_) { + result_ += 1; + } else { + // NOTE: this is a bit sutble hack. + // + // condition: + // - x * operand <= result + // - operand > 0 + // - x >= 0 + // + // Then it is fine to deduce that x <= result / operand. 
+ // - if result > 0, this division round down + // - if result < 0, (result / operand) rounds up and may violate the constraint + // however, given that x is always non-negative, + // it is fine to have this relaxed bound, given that the user of deduce bound + // will respect the bound of x + // + // TODO(tvm-team): think about a better API to incorporate constraint of x. + // e.g. specify an interval of x and return a bound + // that is in the interval and satisfies the condition. + if (target_is_non_neg && sign_operand == kPositive) { + // do nothing + } else { + result_ -= 1; + } + } } - Visit(left ? op->a : op->b); } - Expr result; - bool is_greater{true}; - bool success{true}; + Expr result_; + bool is_greater_{true}; + bool success_{true}; private: void Init(); @@ -180,6 +202,8 @@ class BoundDeducer: public IRVisitor { ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; + // internal analzyer + Analyzer analyzer_; }; class BoundDeduceInputChecker: public IRVisitor { @@ -202,7 +226,7 @@ class BoundDeduceInputChecker: public IRVisitor { void BoundDeducer::Init() { BoundDeduceInputChecker checker; - if (!checker.Check(this)) success = false; + if (!checker.Check(this)) success_ = false; Transform(); } @@ -211,66 +235,65 @@ void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a + 1; + result_ = op->a + 1; } else { // a < b -> a <= b - 1 - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b - 1; + result_ = op->b - 1; } } else if (const LE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b; + result_ = op->b; } } else if (const GT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a - 1; + result_ = op->a - 1; } else { // a > b -> a >= b + 1 - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b + 1; + result_ = op->b + 1; } } else if (const GE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b; + result_ = op->b; } } else { - success = false; + success_ = false; } } void BoundDeducer::Deduce() { Init(); - if (!success) return; + if (!success_) return; Relax(); - if (!success) return; + if (!success_) return; // get the path path_ = GetPath(target_, expr_); if (!path_.size()) { - success = false; + success_ = false; return; } - expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); @@ -278,13 +301,13 @@ void BoundDeducer::Deduce() { void BoundDeducer::Relax() { IntSet a = EvalSet(expr_, relax_map_); - IntSet b = EvalSet(result, relax_map_); + IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { - success = false; + success_ = false; return; } - expr_ = is_greater ? a.min() : a.max(); - result = is_greater ? b.max() : b.min(); + expr_ = is_greater_ ? a.min() : a.max(); + result_ = is_greater_ ? 
b.max() : b.min(); } IntSet DeduceBound(Expr v, Expr e, @@ -292,12 +315,12 @@ IntSet DeduceBound(Expr v, Expr e, const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); - if (!d.success) return IntSet::nothing(); + if (!d.success_) return IntSet::nothing(); Expr min = neg_inf(), max = pos_inf(); - if (d.is_greater) { - min = d.result; + if (d.is_greater_) { + min = d.result_; } else { - max = d.result; + max = d.result_; } return IntSet::interval(min, max); } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index cc54bff596be7..4fa5fe9bd06a9 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -18,10 +18,8 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file compute_expr.h - * \brief Utility integer expression with quick eager simplification. - * This is weaker than Simplify but can be done Eagerly. + * \brief Utility to invoke certan compute operations. */ #ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_ @@ -41,7 +39,7 @@ namespace arith { * \return The result. */ template -inline Expr ComputeExpr(Expr lhs, Expr rhs) { +inline Expr Compute(Expr lhs, Expr rhs) { return OP::make(lhs, rhs); } @@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) { } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a + b; } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a - b; } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a * b; } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a / b; } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a % b; } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return max(a, b); } template<> -inline Expr ComputeExpr(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return min(a, b); } @@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array& values, Expr empty_value) { } Expr res = values[0]; for (size_t i = 1; i < values.size(); ++i) { - res = ComputeExpr(res, values[i]); + res = Compute(res, values[i]); } return res; } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index ec50aef5c51ed..dc6b80a31c7bd 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -155,9 +155,10 @@ template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. 
+ CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, pa->value / pb->value); } if (pa) { diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index e584c8b1ce332..3c5f12a7379e4 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -27,7 +27,6 @@ #include #include #include -#include "compute_expr.h" namespace tvm { namespace arith { @@ -127,18 +126,18 @@ class LinearEqDetector Expr AddCombine(Expr a, Expr b) { if (!a.defined()) return b; if (!b.defined()) return a; - return ComputeExpr(a, b); + return a + b; } Expr SubCombine(Expr a, Expr b) { // Check b first in case they are both undefined if (!b.defined()) return a; if (!a.defined()) return -b; - return ComputeExpr(a, b); + return a - b; } Expr MulCombine(Expr a, Expr b) { if (!a.defined()) return a; if (!b.defined()) return b; - return ComputeExpr(a, b); + return a * b; } }; diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index bc8666e893b4e..773f6c3a85c40 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -155,7 +155,6 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); - TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), @@ -343,13 +342,16 @@ Mutate_(const Sub* op, const Expr& self) { c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + // Proof in the case of floordiv, need positive condition. + // let x = a * c3 + r + // (x + c1) / c3 - x / c3 => (r + c1) / c3 TVM_TRY_REWRITE_IF((x + c1) / c3 - (x + c2) / c3, - ((x + (c1 % c3)) % c3 + (c1 - c2)) / c3, + ((x + ((c2 % c3) + c3) % c3) % c3 + (c1 - c2)) / c3, CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0); TVM_TRY_REWRITE_IF((x + c1) / c3 - x / c3, - ((x + (c1 % c3)) % c3 + c1) / c3, + (x % c3 + c1) / c3, CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0); @@ -1192,9 +1194,9 @@ Mutate_(const LT* op, const Expr& self) { TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y); TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y); - + TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); + TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); - TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1); } return ret; } diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 01cb96ee1323e..fc6b92a87ce1d 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -28,15 +28,27 @@ #include #include #include -#include "arithmetic/Simplify.h" namespace tvm { namespace arith { -// statement simplifier + using namespace ir; class StmtSimplifier : public IRMutator { public: + using IRMutator::Mutate; + + Expr Mutate(Expr expr) final { + return analyzer_.Simplify(expr); + } + + Stmt Simplify(Stmt stmt, Map vrange) { + for (auto kv : vrange) { + analyzer_.Bind(kv.first, kv.second); + } + return Mutate(stmt); + } + Stmt Mutate_(const For* op, const Stmt& s) final { Var loop_var(op->loop_var.node_); analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); @@ -125,28 +137,12 @@ class StmtSimplifier : public IRMutator { std::unordered_map var_dom_; }; - -class CanonicalStmtSimplifier : 
public StmtSimplifier { - public: - using StmtSimplifier::Mutate; - Expr Mutate(Expr expr) final { - return analyzer_.canonical_simplify(expr); - } - - Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - for (auto kv : vrange) { - analyzer_.Bind(kv.first, kv.second); - } - return Mutate(stmt); - } -}; - } // namespace arith namespace ir { Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - return arith::CanonicalStmtSimplifier().CanonicalSimplify( + return arith::StmtSimplifier().Simplify( stmt, vrange); } @@ -158,42 +154,18 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; +Expr Simplify(Expr expr, Map vrange) { + arith::Analyzer analyzer; for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - - -Expr Simplify(Expr a, Map vrange) { - // Simplify top level reduce. - if (const Reduce* r = a.as()) { - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } + analyzer.Bind(kv.first, kv.second); } - return Simplify_(a, vrange); + expr = analyzer.Simplify(expr); + return expr; } -Stmt Simplify(Stmt a, Map vrange) { - return Simplify_(a, vrange); +Stmt Simplify(Stmt stmt, Map vrange) { + return arith::StmtSimplifier().Simplify( + stmt, vrange); } } // namespace ir } // namespace tvm diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 6917200ff9205..c1622338174df 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -342,7 +342,7 @@ Buffer BufferWithOffsetAlignment(Array shape, } return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor); + data_alignment, offset_factor, kDefault); } void GetBinds(const Array& args, diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 22dde1c463892..a32473158bd5f 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,7 +27,6 @@ #include #include #include "codegen_cuda.h" -#include "../arithmetic/compute_expr.h" namespace tvm { namespace codegen { diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 1e56583a37fd0..fde0486483b2b 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e, std::function f) { if (const Ramp* ramp = e.as()) { for (int i = 0; i < ramp->type.lanes(); ++i) { - Expr offset = arith::ComputeExpr( - ramp->base, - arith::ComputeExpr(ramp->stride, i)); + Expr offset = ramp->base + (ramp->stride * i); f(i, MakeValue(offset)); } } else { diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index fd113ca4614a2..7686250c5ce57 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -25,8 +25,8 @@ #include #include #include -#include "../../arithmetic/compute_expr.h" #include "codegen_spirv.h" +#include "../../arithmetic/compute_expr.h" namespace tvm { namespace codegen { @@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) { spirv::Value v = base; if (i != 0) { spirv::Value offset = MakeValue( - arith::ComputeExpr(make_const(op->stride.type(), i), op->stride)); + make_const(op->stride.type(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e, std::function f) { if (const Ramp* ramp = e.as()) { for (int i = 0; i < ramp->type.lanes(); ++i) { - Expr offset = arith::ComputeExpr( - ramp->base, - arith::ComputeExpr(ramp->stride, i)); + Expr offset = ramp->base + ramp->stride * i; f(i, MakeValue(offset)); } } else { diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 8c584c50b3c67..cb5c86710fabb 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include #include "../arithmetic/compute_expr.h" namespace tvm { @@ -48,7 +49,8 @@ Buffer decl_buffer(Array shape, Expr(), name, "", - 0, 0); + 0, 0, + kDefault); } // Split the given expression w.r.t the add operator @@ -364,7 +366,8 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const { n->name + "_slice", n->scope, n->data_alignment, - 0); + 0, + n->buffer_type); } Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr offset) const { @@ -375,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr extent = make_const(self->DefaultIndexType(), 1); } else if (self->strides.size() == self->shape.size()) { int highest_dim = 0; - extent = arith::ComputeExpr( - self->strides[highest_dim], self->shape[highest_dim]) - offset; + extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { extent = arith::ComputeReduce(self->shape, Expr()) - offset; } @@ -404,7 +406,8 @@ Buffer BufferNode::make(Var data, std::string name, std::string scope, int data_alignment, - int offset_factor) { + int offset_factor, + BufferType buffer_type) { auto n = make_node(); n->data = std::move(data); n->dtype = dtype; @@ -427,6 +430,12 @@ Buffer BufferNode::make(Var data, n->elem_offset = std::move(elem_offset); n->data_alignment = data_alignment; n->offset_factor = offset_factor; + n->buffer_type = buffer_type; + if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) { + for (size_t i = 0; i < n->shape.size(); ++i) { + n->strides.push_back(tvm::var("stride")); + } + } return Buffer(n); } diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 42b1331e3736c..78f8c82d97dbf 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -80,7 +80,7 @@ Operation ScanOpNode::make(std::string name, for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype); - CHECK(can_prove(init[i]->shape[0] == axis->dom->min)) + CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; CHECK(prove_equal( state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index 2822393d3f75e..ff4c77accf073 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file arg_binder.cc * \brief Helper utility to match and bind arguments. */ @@ -242,6 +241,21 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, check = IfThenElse::make(Not::make(is_null), check, Stmt()); init_nest_.emplace_back(Block::make(check, Evaluate::make(0))); } + } else if (buffer->buffer_type == kAutoBroadcast) { + Type stype = buffer->DefaultIndexType(); + Expr stride = make_const(stype, 1); + for (size_t i = buffer->shape.size(); i != 0; --i) { + size_t k = i - 1; + std::ostringstream field_name; + field_name << v_strides->name_hint << '[' << k << ']'; + Expr value = cast(buffer->shape[k].type(), + Load::make(tvm_shape_type, v_strides, + IntImm::make(Int(32), k), const_true(1))); + value = tvm::if_then_else(is_null, stride, value); + value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); + Bind_(buffer->strides[k], value, field_name.str(), true); + stride = Simplify(stride * buffer->shape[k]); + } } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index a906ee3e54741..8df5fe1f77572 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -160,7 +160,7 @@ class CopyIntrinInjector : public IRMutator { store_strides[loop_var_size], store->buffer_var->name_hint, GetStorageScope(store->buffer_var.get()), - 0, 0); + 0, 0, kDefault); Buffer src = BufferNode::make( Var(load->buffer_var.node_), load->type, @@ -169,7 +169,7 @@ class CopyIntrinInjector : public IRMutator { src_elem_offset, load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), - 0, 0); + 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 94b4ab3cb4c93..027639caf7028 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include #include +#include #include "ir_util.h" #include "../arithmetic/compute_expr.h" @@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator { Stmt Mutate_(const Allocate* op, const Stmt& s) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.stride = arith::ComputeReduce - (op->extents, Expr()) * op->type.lanes(); + it->second.stride = arith::ComputeReduce( + op->extents, Expr()) * op->type.lanes(); Stmt stmt = IRMutator::Mutate_(op, s); op = stmt.as(); Array new_extents{make_const(op->extents[0].type(), 2)}; @@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator { << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); Expr zero = old_loop->min; - Expr new_ext = arith::ComputeExpr( - old_loop->extent, make_const(old_loop->loop_var.type(), 1)); + Expr new_ext = + old_loop->extent - make_const(old_loop->loop_var.type(), 1); Expr factor = make_const(new_ext.type(), split_loop_); - Expr outer_ext = arith::ComputeExpr
(new_ext, factor); - Expr tail_base = arith::ComputeExpr(outer_ext, factor); + Expr outer_ext = new_ext / factor; + Expr tail_base = outer_ext * factor; Var outer_var(old_loop->loop_var->name_hint + ".outer", old_loop->loop_var.type()); std::unordered_map vmap; std::vector loop_seq; diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 9009416192e08..88e7f4370126f 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file inject_virtual_thread.cc */ #include @@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor { explicit ExprTouched(const std::unordered_set &touched, bool check_write) : touched_var_(touched), check_write_(check_write) {} + void Visit(const NodeRef& n) final { // early stopping if (expr_touched_ && !check_write_) return; @@ -241,8 +241,8 @@ class VTInjector : public IRMutator { visit_touched_var_ = true; Expr offset = Mutate(op->args[2]); Expr extent = Mutate(op->args[3]); - Expr stride = arith::ComputeExpr
( - it->second, make_const(offset.type(), dtype.lanes())); + Expr stride = + it->second / make_const(offset.type(), dtype.lanes()); offset = stride * var_ + offset; return Call::make( op->type, op->name, diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 0a5b7410f3cff..33dbaed83b697 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -466,8 +466,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Stmt body, bool partition_thread_scope) { using namespace arith; + // include hint of var. + hint_map_.insert({var.get(), IntSet::interval(min, max)}); + PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); + + hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); @@ -504,9 +509,9 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); - if (!can_prove(body_begin == min)) { + if (!analyzer_.CanProve(body_begin == min)) { Expr cond = (body_begin - min >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; body_begin = Max::make(body_begin, min); @@ -529,10 +534,10 @@ Stmt LoopPartitioner::TryPartition(const Node* node, bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); - if (!can_prove(middle_interval.max() == max)) { + if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative Expr cond = (max - post_doubt_begin + 1 >= 0); - if (!can_prove(cond)) { + if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; post_doubt_begin = Min::make(post_doubt_begin, max); @@ -554,7 +559,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; - if (!can_prove(body_begin >= post_doubt_begin)) { + if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); @@ -576,8 +581,8 @@ Stmt LoopPartitioner::TryPartition(const Node* node, s = AppendStmts(s, post_stmt); } else { Expr cond = const_true(); - if (!can_prove(body_begin == min)) cond = cond && (var >= body_begin); - if (!can_prove(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); + if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); } s = ConvertSSA(s); @@ -587,7 +592,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, inline Stmt LoopPartitioner::MakeFor(const Node *node, Expr extent, Stmt body) { const For *for_node = static_cast(node); CHECK(for_node); - if (can_prove(extent == make_const(Int(32), 1))) { + if (analyzer_.CanProve(extent == make_const(Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(Int(32), 0)}}); } else { diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index d0490b2152a0d..02c72d03fea89 100644 --- 
a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 7d9d48600f715..bb7260fb5ddec 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -18,8 +18,6 @@ */ /*! - * Copyright (c) 2018 by Contributors - * * Lower warp memory to use local memory * and shuffle intrinsics. * diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 13f46ecb6f7a2..0109ad19d7a67 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,7 +33,6 @@ #include "ir_util.h" #include "arg_binder.h" -#include "../arithmetic/compute_expr.h" namespace tvm { namespace ir { diff --git a/src/pass/narrow_channel_access.cc b/src/pass/narrow_channel_access.cc index 731064edb0121..57f3baf20e108 100644 --- a/src/pass/narrow_channel_access.cc +++ b/src/pass/narrow_channel_access.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -200,7 +200,7 @@ class ChannelAccessRewriter : public IRMutator { Expr base = linear_eq[1]; if (!is_zero(base)) return body; Expr left = ir::Simplify(adv_op->value - coeff * for_op->extent); - if (!can_prove(left >= 0)) return body; + if (!analyzer_.CanProve(left >= 0)) return body; // rewrite access index. ChannelAccessIndexRewriter rw( ch->handle_var.get(), var * coeff, read_access); @@ -233,6 +233,7 @@ class ChannelAccessRewriter : public IRMutator { return body; } + arith::Analyzer analyzer_; std::vector tasks_; }; diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 215f6d7397323..19e7a32e4acf5 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator { stride = ir::Simplify(stride); } rstrides.push_back(stride); - stride = arith::ComputeExpr(stride, shape[dim]); + stride = stride * shape[dim]; } strides = Array(rstrides.rbegin(), rstrides.rend()); } @@ -220,7 +220,7 @@ class StorageFlattener : public IRMutator { Var(key.GetName(), Handle()), op->type, shape, strides, Expr(), key.GetName(), skey.to_string(), - align, 0); + align, 0, kDefault); buf_map_[key] = e; Stmt body = this->Mutate(op->body); @@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator { int first_dim = 0; ret = Allocate::make( e.buffer->data, storage_type, - {arith::ComputeExpr(e.buffer->strides[first_dim], e.buffer->shape[first_dim])}, + {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, make_const(Bool(e.buffer->dtype.lanes()), true), body); } else { shape = e.buffer->shape; @@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator { if (be.bounds.size() != 0) { CHECK_EQ(tuple->args.size(), be.bounds.size() * 2); for (size_t i = 0; i < be.buffer->shape.size(); ++i) { - begins.push_back( - arith::ComputeExpr(tuple->args[2 * i], be.bounds[i]->min)); + begins.push_back(tuple->args[2 * i] - be.bounds[i]->min); extents.push_back(tuple->args[2 * i + 1]); } } else { diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 806a80ad4dc90..eba1cee8b7c70 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -606,7 +606,7 @@ class StoragePlanRewriter : public IRMutator { } // transform to alloc bytes auto type_bits = alloc_type.bits() * alloc_type.lanes(); - bool divided = can_prove(combo_size % type_bits == 0); + bool divided = analyzer_.CanProve(combo_size % type_bits == 0); combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { @@ -920,6 +920,8 @@ class StoragePlanRewriter : public IRMutator { std::unordered_map alloc_map_; // The allocations std::vector > alloc_vec_; + // analyzer + arith::Analyzer analyzer_; }; // Turn alloc into vector alloc diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index ead234e2c4a08..756130886e13f 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2017 by Contributors * Loop unrolling as in Halide pipeline. 
* \file unroll_loop.cc */ @@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator { } Stmt Unroll(const For* op) { - using arith::ComputeExpr; int value = GetExtent(op); // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; @@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator { Stmt unrolled; for (int i = 0; i < value; ++i) { Var lv(op->loop_var.node_); - vmap.Set(lv, - ComputeExpr( - op->min, make_const(op->loop_var.type(), i))); + vmap.Set(lv, op->min + make_const(op->loop_var.type(), i)); Stmt step = Substitute(body, vmap); if (unrolled.defined()) { unrolled = Block::make(unrolled, step); diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 8c3d383c1529a..2d8416e9a9de7 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -18,13 +18,13 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. #include #include #include +#include #include #include #include @@ -132,11 +132,11 @@ class Vectorizer : public IRMutator { if (lanes != 1) { const Ramp* b_ramp = b.as(); const Ramp* a_ramp = a.as(); - if (a_ramp && b.type().lanes() == 1 && can_prove(b > 0)) { + if (a_ramp && b.type().lanes() == 1 && analyzer_.CanProve(b > 0)) { return Ramp::make( a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } - if (b_ramp && a.type().lanes() == 1 && can_prove(a > 0)) { + if (b_ramp && a.type().lanes() == 1 && analyzer_.CanProve(a > 0)) { return Ramp::make( b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } @@ -186,7 +186,7 @@ class Vectorizer : public IRMutator { Expr stride = this->Mutate(op->stride); if (base.type().lanes() > 1 && stride.type().lanes() == 1) { const Ramp* base_ramp = base.as(); - if (can_prove(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { + if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.type(), op->lanes))) { return Ramp::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } @@ -423,6 +423,8 @@ class Vectorizer : public IRMutator { } private: + // analyzer + arith::Analyzer analyzer_; // variable to be replaced Var var_; // the lanes. @@ -483,13 +485,13 @@ class Vectorizer : public IRMutator { const Ramp* a_ramp = a.as(); if (a.type().lanes() == 1 && b_ramp) { return Ramp::make( - arith::ComputeExpr(a, b_ramp->base), - arith::ComputeExpr(make_zero(b_ramp->stride.type()), b_ramp->stride), + arith::Compute(a, b_ramp->base), + arith::Compute(make_zero(b_ramp->stride.type()), b_ramp->stride), b_ramp->lanes); } if (b.type().lanes() == 1 && a_ramp) { return Ramp::make( - arith::ComputeExpr(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 3feb7e4a4b543..7de77c8bcfd4b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -21,6 +21,7 @@ * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ +#include #include #include #include @@ -433,7 +434,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module = Optimize(relay_module, targets_, params); CHECK(relay_module.defined()); // Get the updated function. - func = relay_module->Lookup(relay_module->entry_func->name_hint); + func = relay_module->Lookup("main"); // Generate code for the updated function. 
graph_codegen_ = std::unique_ptr(new GraphCodegen()); diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 7ae1befcfe895..83e4a36ff4f93 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 9b510ad2fd293..9765cf90da18a 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,8 +27,9 @@ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #include +#include #include -#include +#include #include #include diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 5c2e5c4c289a1..91a597baceaf3 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -25,7 +25,7 @@ */ #include #include -#include +#include #include "../../common/arena.h" namespace tvm { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ff2d9e6117abb..913d7addea4d2 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include "compile_engine.h" @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int tag, +ConstructorValue ConstructorValueNode::make(int32_t tag, tvm::Array fields, Constructor constructor) { NodePtr n = make_node(); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 65a7efd4c2051..139dab21e973d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 668c024a8d550..6290ef7c6e932 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc index cf0b952005fcb..2f656c8cef992 100644 --- a/src/relay/backend/vm/vm.cc +++ b/src/relay/backend/vm/vm.cc @@ -28,17 +28,18 @@ #include #include #include -#include +#include namespace tvm { namespace relay { namespace vm { +runtime::vm::VirtualMachine CompileModule(const Module& mod); + using tvm::runtime::Object; using tvm::runtime::ObjectTag; using tvm::runtime::vm::VirtualMachine; - VirtualMachine FromModule(const Module& module, const std::vector& ctxs) { auto vm = CompileModule(module); vm.Init(ctxs); @@ -51,10 +52,10 @@ Object EvaluateModule(const Module& module, const std::vector ctxs, // TODO(zhiics): This measurement is for temporary usage. Remove it later. We // need to introduce a better profiling method. #if ENABLE_PROFILING - DLOG(INFO) << "Entry function is " << module->entry_func << std::endl; + DLOG(INFO) << "Entry function is main." 
<< std::endl; auto start = std::chrono::high_resolution_clock::now(); #endif // ENABLE_PROFILING - Object res = vm.Invoke(module->entry_func->name_hint, vm_args); + Object res = vm.Invoke("main", vm_args); #if ENABLE_PROFILING auto end = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(end - start).count(); diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index b59281a4f1fd9..3eb1d99f5a889 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). */ diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 81017d4fddfa6..42e66261a5533 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include "type_functor.h" #include "../../lang/attr_functor.h" diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index e09d790822274..0434e2ac59c64 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file src/tvm/relay/expr_mutator.cc * \brief A wrapper around ExprFunctor which functionally updates the AST. * @@ -26,6 +26,7 @@ * the cost of using functional updates. */ #include +#include #include "type_functor.h" namespace tvm { @@ -345,7 +346,7 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_API("relay._ir_pass.post_order_visit") +TVM_REGISTER_API("relay._analysis.post_order_visit") .set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); @@ -353,7 +354,7 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit") }); // Implement bind. 
-class ExprBinder : public ExprMutator { +class ExprBinder : public ExprMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) { @@ -383,13 +384,26 @@ class ExprBinder : public ExprMutator { } } + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Var VisitVar(const Var& v) final { + return Downcast(VisitExpr(v)); + } + private: const tvm::Map& args_map_; }; Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { - Expr new_body = ExprBinder(args_map).Mutate(func->body); + Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; for (Var param : func->params) { if (!args_map.count(param)) { @@ -406,7 +420,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { func->type_params, func->attrs); } else { - return ExprBinder(args_map).Mutate(expr); + return ExprBinder(args_map).VisitExpr(expr); } } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index c57475476e589..6039ba272ddc1 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include "type_functor.h" #include "../../lang/attr_functor.h" @@ -412,12 +412,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { return RelayHashHandler().ExprHash(expr); } -TVM_REGISTER_API("relay._ir_pass._expr_hash") +TVM_REGISTER_API("relay._analysis._expr_hash") .set_body_typed([](NodeRef ref) { return static_cast(RelayHashHandler().Hash(ref)); }); -TVM_REGISTER_API("relay._ir_pass._type_hash") +TVM_REGISTER_API("relay._analysis._type_hash") .set_body_typed([](Type type) { return static_cast(RelayHashHandler().TypeHash(type)); }); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 58f614a3cc77c..0ad0a91efd217 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -23,7 +23,8 @@ * \brief The global module in Relay. */ #include -#include +#include +#include #include namespace tvm { @@ -45,18 +46,21 @@ Module ModuleNode::make(tvm::Map global_funcs, n->global_var_map_.Set(kv.first->name_hint, kv.first); } - n->entry_func = GlobalVarNode::make("main"); - for (const auto& kv : n->type_definitions) { // set global typevar map CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) << "Duplicate global type definition name " << kv.first->var->name_hint; n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + n->RegisterConstructors(kv.first, kv.second); } return Module(n); } +bool ModuleNode::ContainGlobalVar(const std::string& name) const { + return global_var_map_.find(name) != global_var_map_.end(); +} + GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); CHECK(it != global_var_map_.end()) @@ -88,8 +92,9 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { } void ModuleNode::Add(const GlobalVar& var, - const Function& func, + const Function& f, bool update) { + Function func = Downcast(DeDup(f)); // Type check the item before we add it to the module. 
auto mod = GetRef(this); Function checked_func = InferType(func, mod, var); @@ -106,15 +111,25 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { + // We hash the global type var name to use as a globally unique prefix for tags. + // The hash will be used as the most significant byte of the tag, with the index of + // the constructor in the less significant bytes + size_t hash = std::hash()(var->var->name_hint); + int32_t prefix = static_cast(hash & 0xff) << 24; + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = prefix | static_cast(i); + constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; + } +} + void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) << "Duplicate global type definition name " << var->var->name_hint; global_type_var_map_.Set(var->var->name_hint, var); - for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = i; - } + RegisterConstructors(var, type); // need to kind check at the end because the check can look up // a definition potentially @@ -157,6 +172,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) const { return this->LookupDef(id); } +Constructor ModuleNode::LookupTag(const int32_t tag) { + auto it = constructor_tag_map_.find(tag); + CHECK(it != constructor_tag_map_.end()) + << "There is no constructor with the tag " << tag; + return (*it).second; +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -174,7 +196,8 @@ Module ModuleNode::FromExpr( } else { func = FunctionNode::make({}, expr, Type(), {}, {}); } - mod->Add(mod->entry_func, func); + auto main_gv = GlobalVarNode::make("main"); + mod->Add(main_gv, func); return mod; } @@ -183,8 +206,27 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") .set_body_typed(ModuleNode::make); -TVM_REGISTER_API("relay._make.Module_Add") -.set_body_method(&ModuleNode::Add); +TVM_REGISTER_API("relay._module.Module_Add") +.set_body([](TVMArgs args, TVMRetValue* ret) { + Module mod = args[0]; + GlobalVar var = args[1]; + NodeRef val = args[2]; + bool update = args[3]; + CHECK(val->derived_from()); + if (val->derived_from()) { + mod->Add(var, Downcast(val), update); + } else if (val->derived_from()) { + GlobalVar gv = Downcast(val); + auto mod_copy = Module(make_node(*mod.operator->())); + mod_copy = transform::EtaExpand()(mod_copy); + auto func = mod_copy->Lookup(gv->name_hint); + mod->Add(var, Downcast(func), update); + } else { + auto func = FunctionNode::make({}, Downcast(val), Type(nullptr), {}); + mod->Add(var, func, update); + } + *ret = mod; +}); TVM_REGISTER_API("relay._module.Module_AddDef") .set_body_method(&ModuleNode::AddDef); @@ -192,44 +234,52 @@ TVM_REGISTER_API("relay._module.Module_AddDef") TVM_REGISTER_API("relay._module.Module_GetGlobalVar") .set_body_method(&ModuleNode::GetGlobalVar); +TVM_REGISTER_API("relay._module.Module_ContainGlobalVar") +.set_body_method(&ModuleNode::ContainGlobalVar); + TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") .set_body_method(&ModuleNode::GetGlobalTypeVar); TVM_REGISTER_API("relay._module.Module_Lookup") .set_body_typed([](Module mod, GlobalVar var) { - return mod->Lookup(var); - }); + return mod->Lookup(var); +}); 
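The tag scheme set up by `RegisterConstructors` can be checked in isolation; the standalone sketch below mirrors its computation, with "List" as an illustrative ADT name that is not taken from the patch.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

int main() {
  // Mirror of RegisterConstructors: the top byte of the 32-bit tag comes from a hash
  // of the ADT name, the low bytes from the constructor's index in its TypeData.
  size_t hash = std::hash<std::string>()("List");
  int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
  int32_t nil_tag = prefix | 0;    // first constructor
  int32_t cons_tag = prefix | 1;   // second constructor
  // ModuleNode::LookupTag(cons_tag) would recover the Constructor from such a tag.
  std::cout << std::hex << nil_tag << " " << cons_tag << std::endl;
  return 0;
}
```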
TVM_REGISTER_API("relay._module.Module_Lookup_str") .set_body_typed([](Module mod, std::string var) { - return mod->Lookup(var); - }); + return mod->Lookup(var); +}); TVM_REGISTER_API("relay._module.Module_LookupDef") .set_body_typed([](Module mod, GlobalTypeVar var) { - return mod->LookupDef(var); - }); + return mod->LookupDef(var); +}); TVM_REGISTER_API("relay._module.Module_LookupDef_str") .set_body_typed([](Module mod, std::string var) { - return mod->LookupDef(var); + return mod->LookupDef(var); +}); + +TVM_REGISTER_API("relay._module.Module_LookupTag") +.set_body_typed([](Module mod, int32_t tag) { + return mod->LookupTag(tag); }); TVM_REGISTER_API("relay._module.Module_FromExpr") .set_body_typed([](Expr e) { - return ModuleNode::FromExpr(e); + return ModuleNode::FromExpr(e); }); TVM_REGISTER_API("relay._module.Module_Update") .set_body_typed([](Module mod, Module from) { - mod->Update(from); - }); + mod->Update(from); +}); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch( - [](const ModuleNode *node, tvm::IRPrinter *p) { - p->stream << "ModuleNode( " << node->functions << ")"; - }); + [](const ModuleNode *node, tvm::IRPrinter *p) { + p->stream << "ModuleNode( " << node->functions << ")"; +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 7a61079204edc..39fc36fba4baf 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -645,11 +645,21 @@ class PrettyPrinter : Doc VisitType_(const FuncTypeNode* node) final { Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "<"; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); + } + doc << PrintVec(type_params); + doc << ">"; + } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "fn (" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RefTypeNode* node) final { diff --git a/src/relay/ir/type_functor.cc b/src/relay/ir/type_functor.cc index 9fca2e0326859..cde68c50daeff 100644 --- a/src/relay/ir/type_functor.cc +++ b/src/relay/ir/type_functor.cc @@ -92,6 +92,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } +Type TypeMutator::VisitType(const Type& t) { + return t.defined() ? TypeFunctor::VisitType(t) : t; +} + // Type Mutator. Array TypeMutator::MutateArray(Array arr) { // The array will do copy on write diff --git a/src/relay/ir/type_functor.h b/src/relay/ir/type_functor.h index 27ac288fe48db..c3ee14eedd487 100644 --- a/src/relay/ir/type_functor.h +++ b/src/relay/ir/type_functor.h @@ -139,6 +139,7 @@ class TypeVisitor : public TypeFunctor { // Mutator that transform a type to another one. class TypeMutator : public TypeFunctor { public: + Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; Type VisitType_(const TensorTypeNode* op) override; Type VisitType_(const IncompleteTypeNode* op) override; diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index cc71968fba585..82424500ffc8e 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -24,7 +24,8 @@ other expressions. This pass can be used for computing convolution in custom layouts or other general weight pre-transformation. 
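For illustration only (not output copied from the patch or its tests), a polymorphic function type that previously printed as `fn (A, A) -> A` would now render its type parameter list explicitly:

```
fn <A>(A, A) -> A
```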
*/ -#include +#include +#include #include #include #include @@ -348,9 +349,6 @@ Expr AlterOpLayout(const Expr& expr) { return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); } -TVM_REGISTER_API("relay._ir_pass.AlterOpLayout") -.set_body_typed(AlterOpLayout); - } // namespace alter_op_layout namespace transform { diff --git a/src/relay/pass/canonicalize_cast.cc b/src/relay/pass/canonicalize_cast.cc index 99f4a7f44e7e7..04fec248f81c9 100644 --- a/src/relay/pass/canonicalize_cast.cc +++ b/src/relay/pass/canonicalize_cast.cc @@ -22,7 +22,7 @@ * \file canonicalize_cast.cc * \brief Canonicalize cast expressions to make operator fusion more efficient. */ -#include +#include #include #include #include diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index ff9e2304a3bc3..fc0c43d200e5d 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -23,7 +23,7 @@ * \brief Canonicalize special operators to basic operators. This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.) */ -#include +#include #include #include #include @@ -61,9 +61,6 @@ Expr CanonicalizeOps(const Expr& e) { return BiasAddSimplifier().Mutate(e); } -TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") -.set_body_typed(CanonicalizeOps); - namespace transform { Pass CanonicalizeOps() { diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index c95c1ddf8e160..d72705c8ce470 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -33,7 +33,7 @@ * convolution branches, such as Inception block. */ -#include +#include #include #include #include @@ -355,9 +355,6 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { return ParallelConv2DCombiner(min_num_branches).Combine(expr); } -TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") -.set_body_typed(CombineParallelConv2D); - namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc new file mode 100644 index 0000000000000..d5d4f69606539 --- /dev/null +++ b/src/relay/pass/de_duplicate.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file de_duplicate.cc + * \brief Use a fresh Id for every Var to make the result well-formed. 
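With the direct `relay._ir_pass.*` registrations deleted in these hunks, the rewrites remain reachable through the pass objects the files keep. A sketch under that assumption, where `mod` is an existing `relay::Module` and the branch threshold is illustrative:

```cpp
// Run the kept transform-namespace passes over a module instead of the removed
// relay._ir_pass.* entry points.
mod = transform::CanonicalizeOps()(mod);
mod = transform::CombineParallelConv2D(3)(mod);  // min_num_branches = 3 is illustrative
```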
+ */ + +#include +#include +#include +#include "../ir/type_functor.h" + +namespace tvm { +namespace relay { + +Expr DeDup(const Expr& e) { + class DeDupMutator : public TypeMutator, + public ExprMutator, + public PatternMutator { + public: + TypeVar Fresh(const TypeVar& tv) { + TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); + type_rename_[tv] = ret; + return ret; + } + + Var Fresh(const Var& v) { + Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation)); + rename_[v] = ret; + return ret; + } + + Expr VisitExpr(const Expr& e) final { + return ExprMutator::VisitExpr(e); + } + + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return rename_.count(v) != 0 ? rename_.at(v) : v; + } + + Expr VisitExpr_(const LetNode* op) final { + Var v = Fresh(op->var); + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } + + Type VisitType(const Type& t) final { + return t.defined() ? TypeMutator::VisitType(t) : t; + } + + Expr VisitExpr_(const FunctionNode* op) final { + tvm::Array type_params; + for (const TypeVar& type_param : op->type_params) { + type_params.push_back(Fresh(type_param)); + } + tvm::Array params; + for (const Var& param : op->params) { + params.push_back(Fresh(param)); + } + return FunctionNode::make(params, + VisitExpr(op->body), + VisitType(op->ret_type), + type_params, + op->attrs); + } + + Pattern VisitPattern(const Pattern& p) final { + return PatternMutator::VisitPattern(p); + } + + Pattern VisitPattern_(const PatternVarNode* op) final { + return PatternVarNode::make(Fresh(op->var)); + } + + Clause VisitClause(const Clause& c) final { + Pattern pat = VisitPattern(c->lhs); + return ClauseNode::make(pat, VisitExpr(c->rhs)); + } + + Type VisitType_(const TypeVarNode* op) final { + TypeVar v = GetRef(op); + return type_rename_.count(v) != 0 ? type_rename_.at(v) : v; + } + + Var VisitVar(const Var& v) final { + return Fresh(v); + } + + private: + std::unordered_map rename_; + std::unordered_map type_rename_; + }; + + Expr ret = DeDupMutator().VisitExpr(e); + CHECK_EQ(FreeVars(ret).size(), FreeVars(e).size()); + return ret; +} + +TVM_REGISTER_API("relay._transform.dedup") +.set_body_typed(DeDup); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 7e186f80df929..54075f0699e6f 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -28,8 +28,9 @@ * CalcDep turn an expr into a dependency graph of expr, * GenLet turn the dependency graph into a let list, taking only the used value. */ -#include +#include #include +#include #include "let_list.h" namespace tvm { @@ -156,9 +157,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { return CalcDep::Eliminate(e, inline_once); } -TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") -.set_body_typed(DeadCodeElimination); - namespace transform { Pass DeadCodeElimination(bool inline_once) { diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h index 7f53918ebcb7f..5e2b08c352f09 100644 --- a/src/relay/pass/dependency_graph.h +++ b/src/relay/pass/dependency_graph.h @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors. * \file tvm/relay/pass/dependency_graph.h - * \brief + * \brief create a dependency graph. 
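A minimal sketch of how the new pass is meant to be used, mirroring the `DeDup` call added to `ModuleNode::Add` earlier in this patch; `f`, `gv`, and `mod` are assumed to exist.

```cpp
// Fresh Ids for every Var/TypeVar, so the same body can be added to a module twice
// without aliasing binders; DeDup itself asserts the set of free vars is preserved.
Function fresh = Downcast<Function>(DeDup(f));
CHECK_EQ(FreeVars(fresh).size(), FreeVars(f).size());
mod->Add(gv, fresh);
```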
*/ #ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ #define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 8eeb493f1feba..aec974b184d3f 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -34,7 +34,6 @@ #include #include #include -#include #include #include @@ -559,13 +558,13 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo") +TVM_REGISTER_API("relay._analysis.CollectDeviceInfo") .set_body_typed(CollectDeviceInfo); -TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") +TVM_REGISTER_API("relay._analysis.RewriteDeviceAnnotation") .set_body_typed(RewriteAnnotatedOps); -TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") +TVM_REGISTER_API("relay._analysis.CollectDeviceAnnotationOps") .set_body_typed(CollectDeviceAnnotationOps); namespace transform { diff --git a/src/relay/pass/eliminate_common_subexpr.cc b/src/relay/pass/eliminate_common_subexpr.cc index 883681adcaf45..33a791b2bd996 100644 --- a/src/relay/pass/eliminate_common_subexpr.cc +++ b/src/relay/pass/eliminate_common_subexpr.cc @@ -27,7 +27,7 @@ * to replace an expression with a previously appeared expression with the same input and * attributes. The fskip callback argument allows us to skip specific expressions. */ -#include +#include #include #include #include @@ -85,9 +85,6 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { return CommonSubexprEliminator(callback)(expr); } -TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr") -.set_body_typed(EliminateCommonSubexpr); - namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 3139d41d63937..e73e3778395e9 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -25,7 +25,8 @@ * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x). 
* */ -#include +#include +#include namespace tvm { namespace relay { @@ -44,10 +45,8 @@ Expr EtaExpand(const Expr& e, const Module& mod) { original_type_params = func->type_params; ret_type = func->ret_type; } else { - auto inferred = InferType(e, mod); - CHECK(inferred->is_type()); - - auto func = GetRef(inferred.as_derived()); + CHECK(e->is_type()); + auto func = GetRef(e.as_derived()); original_params = func->params; original_type_params = func->type_params; ret_type = func->ret_type; @@ -62,19 +61,18 @@ Expr EtaExpand(const Expr& e, const Module& mod) { auto new_func = FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params); - return InferType(new_func, mod); + return new_func; } -TVM_REGISTER_API("relay._ir_pass.eta_expand").set_body_typed(EtaExpand); - namespace transform { Pass EtaExpand() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { - return Downcast(EtaExpand(f, m)); - }; - return CreateFunctionPass(pass_func, 1, "EtaExpand", {}); + return Downcast(EtaExpand(f, m)); + }; + Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {}); + return Sequential({expanded, InferType()}); } TVM_REGISTER_API("relay._transform.EtaExpand") diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index e86ca06211126..df3a5d7ecec52 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -23,7 +23,7 @@ * \brief Detect features used in Expr/Module */ #include -#include +#include #include #include #include @@ -97,7 +97,7 @@ Array PyDetectFeature(const Expr& expr, const Module& mod) { return static_cast>(fs); } -TVM_REGISTER_API("relay._ir_pass.detect_feature") +TVM_REGISTER_API("relay._analysis.detect_feature") .set_body_typed(PyDetectFeature); } // namespace relay diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 815407038b082..7b896a8d0f7fe 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -21,7 +21,7 @@ * Copyright (c) 2018 by Contributors * \file constant_folding.cc */ -#include +#include #include #include #include @@ -156,9 +156,13 @@ class ConstantFolder : public ExprMutator { } // Constant evaluate a expression. Expr ConstEvaluate(Expr expr) { - expr = InferType(expr, Module(nullptr)); - expr = FuseOps(expr, 0, Module(nullptr)); - expr = InferType(expr, Module(nullptr)); + std::vector passes = {transform::FuseOps(0), + transform::InferType()}; + auto mod = ModuleNode::FromExpr(expr); + auto seq = transform::Sequential(passes); + mod = seq(mod); + auto entry_func = mod->Lookup("main"); + expr = expr.as() == nullptr ? entry_func->body : entry_func; return ValueToExpr(executor_(expr)); } // Evaluate shape_of op @@ -213,9 +217,6 @@ Expr FoldConstant(const Expr& expr) { Module(nullptr), ctx, target)).Mutate(expr); } -TVM_REGISTER_API("relay._ir_pass.FoldConstant") -.set_body_typed(FoldConstant); - namespace transform { Pass FoldConstant() { diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 53089807ace5f..868a08f8b5769 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -26,7 +26,7 @@ * conv/dense operators. 
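Restated as a standalone sketch for readability (names as in the `ConstEvaluate` hunk above): the expression is wrapped in a fresh module, a `Sequential` of `FuseOps(0)` and `InferType` runs over it, and the checked result is read back from "main".

```cpp
auto mod = ModuleNode::FromExpr(expr);  // `expr` is the expression being folded
auto seq = transform::Sequential({transform::FuseOps(0), transform::InferType()});
mod = seq(mod);
auto entry_func = mod->Lookup("main");
// A bare expression becomes the body of "main"; a function is returned whole.
Expr evaluated = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
```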
*/ #include -#include +#include #include #include #include @@ -545,10 +545,6 @@ Expr ForwardFoldScaleAxis(const Expr& data) { data, "FScaleAxisForwardRewrite", fcontext); } -// Expose the FoldScaleAxisFoward -TVM_REGISTER_API("relay._ir_pass.forward_fold_scale_axis") -.set_body_typed(ForwardFoldScaleAxis); - //---------------------------------------- // Implement backward transformations. //---------------------------------------- @@ -947,9 +943,6 @@ Expr BackwardFoldScaleAxis(const Expr& data) { return make_node()->Fold(data); } -TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis") -.set_body_typed(BackwardFoldScaleAxis); - } // namespace fold_scale_axis namespace transform { @@ -964,6 +957,9 @@ Pass ForwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.ForwardFoldScaleAxis") +.set_body_typed(ForwardFoldScaleAxis); + Pass BackwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = [=](Function f, Module m, PassContext pc) { @@ -974,6 +970,9 @@ Pass BackwardFoldScaleAxis() { {ir::StringImm::make("InferType")}); } +TVM_REGISTER_API("relay._transform.BackwardFoldScaleAxis") +.set_body_typed(BackwardFoldScaleAxis); + Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. Therefore, we can // register it as a sequential pass. diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc index 8ad61270e33a8..6c66d6e982a71 100644 --- a/src/relay/pass/forward_rewrite.cc +++ b/src/relay/pass/forward_rewrite.cc @@ -23,9 +23,9 @@ * \file forward_rewrite.cc * \brief Apply rewriting rules in a forward fashion. */ -#include #include #include +#include #include "pass_util.h" namespace tvm { @@ -206,37 +206,5 @@ Expr ForwardRewrite(const Expr& expr, return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } -namespace transform { - -using std::function; - -Pass ForwardRewrite(const std::string& rewrite_map_attr_name, - function fcontext, - function fmulti_ref_trigger) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ForwardRewrite(f, - rewrite_map_attr_name, - fcontext, - fmulti_ref_trigger)); - }; - return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {}); -} - -Pass ForwardRewrite(const FForwardRewrite& rewrite_func, - function fcontext, - function fmulti_ref_trigger) { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ForwardRewrite(f, - rewrite_func, - fcontext, - fmulti_ref_trigger)); - }; - return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {}); -} - -} // namespace transform - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 9f940e54953b9..cdd2837463659 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -26,7 +26,7 @@ * Fuse necessary ops into a single one. 
*/ #include -#include +#include #include #include #include @@ -963,9 +963,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) { } } -TVM_REGISTER_API("relay._ir_pass.FuseOps") -.set_body_typed(FuseOps); - namespace transform { Pass FuseOps(int fuse_opt_level) { diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 5d26f7adcff77..1abe7a94b621f 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -26,7 +26,8 @@ #include #include #include -#include +#include +#include #include "pattern_util.h" #include "let_list.h" #include "../ir/type_functor.h" @@ -246,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._ir_pass.first_order_gradient") +TVM_REGISTER_API("relay._analysis.first_order_gradient") .set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { @@ -351,7 +352,7 @@ Expr Gradient(const Expr& re, const Module& mod) { return FunctionNode::make(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_API("relay._ir_pass.gradient") +TVM_REGISTER_API("relay._transform.gradient") .set_body_typed(Gradient); } // namespace relay diff --git a/src/relay/pass/kind_check.cc b/src/relay/pass/kind_check.cc index 976a2ef8ec54d..c0f4a7c5967d1 100644 --- a/src/relay/pass/kind_check.cc +++ b/src/relay/pass/kind_check.cc @@ -32,7 +32,7 @@ * We check this by ensuring the `dtype` field of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ -#include +#include #include #include "../ir/type_functor.h" @@ -183,7 +183,7 @@ Kind KindCheck(const Type& t, const Module& mod) { return kc.Check(t); } -TVM_REGISTER_API("relay._ir_pass.check_kind") +TVM_REGISTER_API("relay._analysis.check_kind") .set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 1) { *ret = KindCheck(args[0], ModuleNode::make({}, {})); diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 9f56b22fc13e9..73c5fe3abc22c 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, @@ -46,6 +46,11 @@ namespace relay { */ class LetList { public: + ~LetList() { + if (lets_.size() > 0 && !used_) { + LOG(WARNING) << "letlist not used"; + } + } /*! * \brief insert a binding. * @@ -64,13 +69,13 @@ class LetList { /*! * \brief insert a binding. * - * \param ty the type of the binding. - * * \param expr the value of the binding. * + * \param ty the type of the binding. + * * \return a Var that hold the inserted expr. */ - Var Push(Type ty, Expr expr) { + Var Push(Expr expr, Type ty) { return Push(VarNode::make("x", ty), expr); } @@ -82,7 +87,7 @@ class LetList { * \return a Var that hold the inserted expr. */ Var Push(Expr expr) { - return Push(Type(), expr); + return Push(expr, Type()); } /*! 
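The `Push` overload swap above only changes argument order; a sketch of the new call shape, with `ll`, `value`, and `ty` standing in for an existing `LetList`, `Expr`, and `Type`:

```cpp
Var x = ll.Push(value, ty);  // annotated binding, was ll.Push(ty, value)
Var y = ll.Push(value);      // the single-argument overload is unchanged
```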
@@ -129,6 +134,12 @@ class LetList { return ll.Get(f(&ll)); } + static Expr Let(const Expr& e, const std::function& f) { + return With([&](LetList* ll) { + return f(ll->Push(e)); + }); + } + private: std::vector > lets_; bool used_ = false; diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 3d77fabe6fe91..48a0dfb847466 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include "pattern_util.h" @@ -88,11 +88,44 @@ int64_t ConvMacCount(const Call& call_node) { << "The dimension of the output tensor in Conv 2D should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_attr->groups, 0) - << "The number of input channels is not divisble by groups."; + << "The number of input channels is not divisble by groups."; count *= input_channel/conv_2d_attr->groups; return count; } +int64_t Conv2dTransposeMacCount(const Call& call_node) { + if (!call_node->checked_type_.defined()) { + LOG(WARNING) << "The infer type pass should be called before the mac count pass"; + return 0; + } + Array args = call_node->args; + CHECK(args.size() == 2) + << "The number of input arguments of a CONV 2D Transpose node should be 2."; + const auto* conv_2d_transpose_attr = call_node->attrs.as(); + const auto* data_type = args[0]->checked_type().as(); + Array data_shape = data_type->shape; + std::string data_layout = conv_2d_transpose_attr->data_layout; + int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); + int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); + CHECK(C_ind != -1) + << "There is no input channel dimension."; + int64_t input_channel = static_cast(data_shape[C_ind].as()->value); + if (c_ind != -1) + input_channel *= static_cast(data_shape[c_ind].as()->value); + Array kernel_size = conv_2d_transpose_attr->kernel_size; + CHECK(kernel_size.size() == 2) + << "The dimension of the kernel in Conv 2D Transpose should be 2."; + const auto* expr = call_node->checked_type().as(); + Array output_tensor = expr->shape; + CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) + << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; + int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); + CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) + << "The number of input channels is not divisble by groups."; + count *= input_channel/conv_2d_transpose_attr->groups; + return count; +} + int64_t DenseMacCount(const Call& call_node) { if (!call_node->checked_type_.defined()) { LOG(WARNING) << "The infer type pass should be called before the mac count pass"; @@ -106,13 +139,13 @@ int64_t DenseMacCount(const Call& call_node) { Array data_shape = data_type->shape; Array weight_shape = weight_type->shape; CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; + << "The dimension of an input tensor to Dense node should be 2."; int64_t d1 = static_cast(data_shape[0].as()->value); int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); CHECK(d2 == d4) - << "The dimensions of input arguments do not match."; + << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } @@ -120,6 +153,9 @@ int64_t DenseMacCount(const Call& call_node) { 
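A sketch of the new `LetList::Let` convenience added above: it binds the expression once and lets the callback build the remainder from the resulting `Var`. The `add` call is illustrative, not from the patch.

```cpp
// Bind `e` to a fresh var, then use that var twice in the generated body.
Expr doubled = LetList::Let(e, [](const Var& v) -> Expr {
  return CallNode::make(Op::Get("add"), {v, v}, Attrs(), {});
});
```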
RELAY_REGISTER_OP("nn.conv2d") .set_attr("FMacCount", ConvMacCount); +RELAY_REGISTER_OP("nn.conv2d_transpose") +.set_attr("FMacCount", Conv2dTransposeMacCount); + RELAY_REGISTER_OP("nn.dense") .set_attr("FMacCount", DenseMacCount); @@ -129,7 +165,8 @@ class MacCounter : private ExprVisitor { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { - LOG(INFO) << "This pass only counts MACs in direct CONV 2D and Dense ops"; + LOG(INFO) << "This pass only counts MACs in direct CONV 2D, " + << "CONV 2D Transpose and Dense ops"; MacCounter counter; counter(expr); return counter.count_; @@ -151,7 +188,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") +TVM_REGISTER_API("relay._analysis.GetTotalMacNumber") .set_body_typed(GetTotalMacNumber); } // namespace mac_count diff --git a/src/relay/pass/match_exhaustion.cc b/src/relay/pass/match_exhaustion.cc index 173d6eacf528f..cc00a54cde0ab 100644 --- a/src/relay/pass/match_exhaustion.cc +++ b/src/relay/pass/match_exhaustion.cc @@ -32,7 +32,6 @@ #include #include #include -#include #include namespace tvm { @@ -236,15 +235,15 @@ Array UnmatchedCases(const Match& match, const Module& mod) { } // expose for testing only -TVM_REGISTER_API("relay._ir_pass.unmatched_cases") -.set_body_typed(const Match&, - const Module&)>([](const Match& match, - const Module& mod_ref) { - Module call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = ModuleNode::make({}, {}); - } - return UnmatchedCases(match, call_mod); - }); +TVM_REGISTER_API("relay._analysis.unmatched_cases") +.set_body_typed(const Match&, const Module&)>( + [](const Match& match, const Module& mod_ref) { + Module call_mod = mod_ref; + if (!call_mod.defined()) { + call_mod = ModuleNode::make({}, {}); + } + return UnmatchedCases(match, call_mod); + }); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index b95c5844f8a40..3b7628a10789c 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file partial_eval.cc * * @@ -64,7 +64,7 @@ * 3: The generated code reuses bindings (although they are not shadowed), * so we have to deduplicate them. * - * 4: In the generated code, multiple VarNode might have same Id. + * 4: In the generated code, as it calls TypeSubst, multiple VarNode might have same Id. * While it is permitted, most pass use NodeHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. @@ -91,7 +91,8 @@ * * These assumptions do not affect the correctness of the algorithm, however. */ -#include +#include +#include #include #include #include @@ -215,9 +216,9 @@ Static MkSRef() { } using Func = std::function&, - const Attrs&, - const Array&, - LetList*)>; + const Attrs&, + const Array&, + LetList*)>; struct SFuncNode : StaticNode { Func func; @@ -255,6 +256,7 @@ class Environment { void Insert(const Var& v, const PStatic& ps) { CHECK(ps.defined()); + CHECK_EQ(env_.back().locals.count(v), 0); env_.back().locals[v] = ps; } @@ -286,12 +288,17 @@ class Environment { /*! * \brief As our store require rollback, we implement it as a frame. - * every time we need to copy the store, a new frame is insert. - * every time we roll back, a frame is popped. + * + * Every time we need to copy the store, a new frame is inserted.
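A quick sanity check of the `Conv2dTransposeMacCount` formula above, with illustrative shapes that are not from the patch's tests:

```cpp
// output tensor 1 x 64 x 56 x 56, kernel 3 x 3, 32 input channels, groups = 1:
//   GetCartesianProd(output) = 1 * 64 * 56 * 56 = 200704
//   GetCartesianProd(kernel) = 3 * 3            = 9
//   count = 200704 * 9 * (32 / 1)               = 57802752 MACs
```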
+ * Every time we roll back, a frame is popped. */ struct StoreFrame { std::unordered_map store; - /*! \brief on unknown effect, history_valid is set to true to signal above frame is outdated */ + /*! + * \brief On unknown effect, history_valid is set to true to signal above frame is outdated. + * + * It only outdate the frame above it, but not the current frame. + */ bool history_valid = true; explicit StoreFrame(const std::unordered_map& store) : store(store) { } StoreFrame() = default; @@ -309,6 +316,7 @@ class Store { } void Insert(const SRefNode* r, const PStatic& ps) { + CHECK(r); store_.back().store[r] = ps; } @@ -316,19 +324,21 @@ class Store { PStatic Lookup(const SRefNode* r) { auto rit = store_.rbegin(); while (rit != store_.rend()) { - if (!rit->history_valid) { - return PStatic(); - } if (rit->store.find(r) != rit->store.end()) { return rit->store.find(r)->second; } + if (!rit->history_valid) { + return PStatic(); + } ++rit; } return PStatic(); } void Invalidate() { - store_.back().history_valid = false; + StoreFrame sf; + sf.history_valid = false; + store_.push_back(sf); } private: @@ -340,6 +350,10 @@ class Store { store_->store_.push_back(StoreFrame()); } ~StoreFrameContext() { + // push one history valid frame off. + while (!store_->store_.back().history_valid) { + store_->store_.pop_back(); + } store_->store_.pop_back(); } }; @@ -425,8 +439,6 @@ TVM_ADD_FILELINE) Expr StripWithFuncId(const Expr& e); -Expr DeDup(const Expr& e); - Function AsFunc(const Expr& e) { if (e.as()) { return Downcast(e); @@ -443,13 +455,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars, - const Module& mod) : - mod_(mod) { - for (const Var& v : free_vars) { - env_.Insert(v, NoStatic(v)); - } - } + PartialEvaluator(const Module& mod) : mod_(mod) { } PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -485,23 +491,23 @@ class PartialEvaluator : public ExprFunctor return env_.Lookup(GetRef(op)); } - PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - GlobalVar gv = GetRef(op); + PStatic VisitGlobalVar(const GlobalVar& gv) { + CHECK(mod_.defined()); if (gv_map_.count(gv) == 0) { - if (mod_.defined()) { - Function func = mod_->Lookup(gv); - InitializeFuncId(func); - Func f = VisitFuncStatic(func, gv); - gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); - func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); - mod_->Update(gv, func); - } else { - gv_map_.insert({gv, NoStatic(gv)}); - } + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); } return gv_map_.at(gv); } + PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { + return VisitGlobalVar(GetRef(op)); + } + PStatic VisitExpr_(const LetNode* op, LetList* ll) final { env_.Insert(op->var, VisitExpr(op->value, ll)); return VisitExpr(op->body, ll); @@ -630,7 +636,7 @@ class PartialEvaluator : public ExprFunctor subst.Set(func->type_params[i], type_args[i]); } for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], Type()); + subst.Set(func->type_params[i], IncompleteTypeNode::make(kType)); } std::vector