diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 02134b64b619..97cf467cca07 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -33,10 +33,10 @@ ############################## # Top-level Fallbacks ############################## -include/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics -src/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics -apps/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics -python/** @tqchen @jroesch @yzhliu @icemelon9 @junrushao1994 @comaniac @zhiics +include/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics +src/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics +apps/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics +python/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics # Thirdparty license audit 3rdparty/** @tqchen @jroesch @@ -67,11 +67,11 @@ rust/** @jroesch @nhynes @nhynes vta/** @tmoreau89 @vegaluisjose # docs -docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 -tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 +docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon +tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon # tests -tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon9 +tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon ############################## # Specific modules @@ -129,9 +129,9 @@ include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm python/tvm/micro/** @areusch @liangfu @tmoreau89 @manupa-arm # relay -src/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -include/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -python/tvm/relay/** @jroesch @slyubomirsky @icemelon9 @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +src/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +include/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 +python/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 # relay/qnn diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1b9ebb3411e2..ca7d0f2a5052 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -71,6 +71,23 @@ jobs: run: >- conda build --output-folder=conda/pkg conda/recipe && conda install tvm -c ./conda/pkg + - name: Build iOS RPC@MacOS + if: startsWith(matrix.os, 'macOS') + run: | + IOS_VERSION="14.0" + CMAKE_FLAGS="-DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_SYSTEM_VERSION=${IOS_VERSION} \ + -DCMAKE_OSX_SYSROOT=iphonesimulator \ + -DCMAKE_OSX_ARCHITECTURES=x86_64 \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON \ + -DUSE_IOS_RPC=ON" + + mkdir build-ios-simulator + cd build-ios-simulator + cmake .. ${CMAKE_FLAGS} + cmake --build . --target ios_rpc - name: Test@Win if: startsWith(matrix.os, 'windows') shell: cmd /C call {0} @@ -81,3 +98,12 @@ jobs: shell: bash -l {0} run: >- python -m pytest -v tests/python/all-platform-minimal-test + - name: Test iOS RPC@MacOS + if: startsWith(matrix.os, 'macOS') + shell: bash -l {0} + run: >- + python -m pip install tornado psutil cloudpickle && + export PYTHONPATH=tests/python/contrib:${PYTHONPATH} && + export BUNDLE_ID=org.apache.tvmrpc && + export BUNDLE_PATH=build-ios-simulator/apps/ios_rpc/ios_rpc/src/ios_rpc-build/Release-iphonesimulator/tvmrpc.app && + python -m pytest -v tests/python/contrib/test_rpc_server_device.py diff --git a/.gitignore b/.gitignore index 7141aaeb192f..491116c163fd 100644 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,7 @@ perf .bash_history *.json *.params +*.ro *.onnx *.h5 synset.txt @@ -240,4 +241,4 @@ conda/pkg # Downloaded models/datasets .tvm_test_data .dgl -.caffe2 \ No newline at end of file +.caffe2 diff --git a/.gitmodules b/.gitmodules index 6ef740e33153..8dfda44d10a0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "3rdparty/libbacktrace"] path = 3rdparty/libbacktrace url = https://github.com/tlc-pack/libbacktrace.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000000..a3bcc6981d5d --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit a3bcc6981d5dad3afb212689e2c7853d1b1ee45d diff --git a/CMakeLists.txt b/CMakeLists.txt index 24f0653b3a78..7293abb60f7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,7 @@ tvm_option(USE_MICRO "Build with Micro TVM support" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) +tvm_option(USE_PT_TVMDSOOP "Build with PyTorch TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) tvm_option(USE_CMSISNN "Build with Arm CMSIS-NN" OFF) @@ -69,6 +70,7 @@ tvm_option(USE_MKLDNN "Build with MKLDNN" OFF) tvm_option(USE_DNNL_CODEGEN "Enable MKLDNN (DNNL) codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) tvm_option(USE_CUBLAS "Build with cuBLAS" OFF) +tvm_option(USE_CUTLASS "Build with CUTLASS" OFF) tvm_option(USE_THRUST "Build with Thrust" OFF) tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF) tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) @@ -413,6 +415,8 @@ endif(USE_PIPELINE_EXECUTOR) # Module rules include(cmake/modules/VTA.cmake) include(cmake/modules/StandaloneCrt.cmake) +include(cmake/modules/Zephyr.cmake) +include(cmake/modules/Arduino.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/Hexagon.cmake) include(cmake/modules/OpenCL.cmake) @@ -428,6 +432,7 @@ include(cmake/modules/contrib/EthosU.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) +include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/ExampleTargetHooks.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) @@ -437,6 +442,7 @@ include(cmake/modules/contrib/NNPack.cmake) include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) +include(cmake/modules/contrib/PT_TVMDSOOP.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/BNNS.cmake) include(cmake/modules/contrib/ONNX.cmake) @@ -471,7 +477,7 @@ add_library(tvm SHARED $ $) + add_library(tvm_runtime STATIC $ $) set(NOTICE_MULTILINE "You have build static version of the TVM runtime library. Make " "sure to use --whole-archive when linking it into your project.") @@ -479,7 +485,7 @@ if(BUILD_STATIC_RUNTIME) add_custom_command(TARGET tvm_runtime POST_BUILD COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE}) else() - add_library(tvm_runtime SHARED $) + add_library(tvm_runtime SHARED $ $) set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") endif() set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") @@ -500,6 +506,8 @@ if(USE_MICRO) # Unix Makefiles generator, need to add these explicit target-level dependency) add_dependencies(tvm host_standalone_crt) add_dependencies(tvm_runtime host_standalone_crt) + add_dependencies(tvm_runtime zephyr) + add_dependencies(tvm_runtime arduino) endif() if(USE_CPP_RPC) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 6c63793fa217..d2c2745c8f85 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -59,7 +59,7 @@ We do encourage everyone to work anything they are interested in. - [Giuseppe Rossini](https://github.com/giuseros): @giuseros - aot, arm - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - frontends - [Junru Shao](https://github.com/junrushao1994) (PMC): @junrushao1994 - relay, compiler -- [Haichen Shen](https://github.com/icemelon9) (PMC): @icemelon9 - relay, topi +- [Haichen Shen](https://github.com/icemelon) (PMC): @icemelon - relay, topi - [Siva Rama Krishna Reddy](https://github.com/srkreddy1238): @srkreddy1238 - frontends, golang - [Zhixun Tan](https://github.com/phisiart): @phisiart - opengl, web - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - topi, compiler, runtime @@ -90,6 +90,7 @@ We do encourage everyone to work anything they are interested in. - [Siyuan Feng](https://github.com/Hzfengsy): @Hzfengsy - [Josh Fromm](https://github.com/jwfromm): @jwfromm - [Sergei Grechanik](https://github.com/sgrechanik-h): @sgrechanik-h +- [Mehrdad Hessar](https://github.com/mehrdadh): @mehrdadh - [Bohan Hou](https://github.com/spectrometerHBH): @spectrometerHBH - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - [Luke Hutton](https://github.com/lhutton1): @lhutton1 @@ -130,12 +131,14 @@ We do encourage everyone to work anything they are interested in. - [Giuseppe Rossini](https://github.com/giuseros): @giuseros - [Siju Samuel](https://github.com/siju-samuel): @siju-samuel - [Junru Shao](https://github.com/junrushao1994): @junrushao1994 -- [Haichen Shen](https://github.com/icemelon9): @icemelon9 +- [Haichen Shen](https://github.com/icemelon): @icemelon - [Xingjian Shi](https://github.com/sxjscience): @sxjscience +- [Mark Shields](https://github.com/mbs-octoml): @mbs-octoml - [Christopher Sidebottom](https://github.com/mousius): @mousius - [Siva Rama Krishna Reddy](https://github.com/srkreddy1238): @srkreddy1238 - [Dmitriy Smirnov](https://github.com/d-smirnov): @d-smirnov - [Jon Soifer](https://github.com/soiferj): @soiferj +- [Chris Sullivan](https://github.com/csullivan): @csullivan - [Zhixun Tan](https://github.com/phisiart): @phisiart - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch - [Jorn Tuyls](https://github.com/jtuyls): @jtuyls @@ -148,6 +151,7 @@ We do encourage everyone to work anything they are interested in. - [Logan Weber](https://github.com/weberlo): @weberlo - [Matt Welsh](https://github.com/mdw-octoml): @mdw-octoml - [Jian Weng](https://github.com/were): @were +- [wrongtest](https://github.com/wrongtest): @wrongtest - [Yong Wu](https://github.com/yongwww): @yongwww - [Zhao Wu](https://github.com/FrozenGene): @FrozenGene - [Bing Xu](https://github.com/antinucleon): @antinucleon @@ -155,6 +159,7 @@ We do encourage everyone to work anything they are interested in. - [Hao Yu](https://github.com/comaniac): @comaniac - [Joshua Z. Zhang](https://github.com/zhreshold): @zhreshold - [Lianmin Zheng](https://github.com/merrymercy): @merrymercy +- [Xiyou Zhou](https://github.com/zxybazh): @zxybazh ## List of Contributors - [Full List of Contributors](https://github.com/apache/tvm/graphs/contributors) diff --git a/Jenkinsfile b/Jenkinsfile index 6a7a1f4e3d36..a5e1d2824566 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -46,8 +46,8 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = "tlcpack/ci-lint:v0.67" -ci_gpu = "tlcpack/ci-gpu:v0.77" -ci_cpu = "tlcpack/ci-cpu:v0.78" +ci_gpu = "tlcpack/ci-gpu:v0.78" +ci_cpu = "tlcpack/ci-cpu:v0.79" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.74" ci_qemu = "tlcpack/ci-qemu:v0.08" @@ -353,7 +353,7 @@ stage('Unit Test') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_i386} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_unittest.sh" - sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration.sh" + sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_integration_i386only.sh" sh "${docker_run} ${ci_i386} ./tests/scripts/task_python_vta_fsim.sh" junit "build/pytest-results/*.xml" } diff --git a/KEYS b/KEYS index a3b7b0e3149f..a819d8f3bdda 100644 --- a/KEYS +++ b/KEYS @@ -300,3 +300,119 @@ Q2YsCQ/Br0XhJvC+i6OYgCI1iGLINTe9wjsi2ei8ZI+2G9XY62sN0orIIjIadns+ 8WGuWI9h3RBLY7aFMLpl02cXrsOiMcXC1Uk/e6e14Xpu+Y6IG4KKkUM= =GEwA -----END PGP PUBLIC KEY BLOCK----- +pub rsa4096 2021-11-10 [SC] + C5E5C09030E7BD32DF9A67CE35ABC9676004ADAE +uid [ultimate] Junru Shao +sig 3 35ABC9676004ADAE 2021-11-10 Junru Shao +sub rsa4096 2021-11-10 [E] +sig 35ABC9676004ADAE 2021-11-10 Junru Shao + +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBGGL/BwBEAC8krTtoeZUNgWVTEBZ8Gm77xwy0W1NjpqY6+cT01xW1vlsjMBl +MoR2bGA8aR+vNERI/CfRN8uyplWyoCK7fbnk0Rcd81nvdMqYiXWg55PetJ+oPm+B +8j26ssUe+Umg0cwa4ZUbdmicSSlousjR75XlasanrGggn1iH0ltiwvkyxKIlppo9 +UTgh6Db3sK3i+hNrwZmMliG03CpdZqh9luCQD2KaHhL2v63fzEo2mKJLHFQGYmRR +dkCvF8GNEkoyVbOVRY+jnZ97C6U4XigAwwqi7kBp9QJ5DE7xzjXwCOS2QNUzvdjF +/OE3zTVJmx5qSD9i5u69A3iXBEfVd19gCDiJIkOttgfNgKy+atK+Bmc5iM9aizCA +jZQAt0uOsPXzNkiiiJoTp6egvt05F/7Z/cy+UQZb+GQNRqMr+8Z77QjH1fAAB8qz +q+Z/W6Gazws+CiqkVrvMUKCIj3AxHWeiUDwD1KGap3WkpocEuJ2IXuYUlySDIFXv +Iigm0a0KFt8Ex4cfz3GNS6eH0bjHn6YIebIQIRRYI4kozy/JMAYJ78Tx8Rp0WY38 +85PXQZazHRriVttc8YrnK8uAHjN01COOyGkwYp20Xqw7dOoYCnbhObYvoDDHRtMm +2O7TtK6sfnyWhL9ZRGOWyqoIw+4TIh+sS0z1dj7oyWeSaPHTCbj/7CneZwARAQAB +tCFKdW5ydSBTaGFvIDxqdW5ydXNoYW9AYXBhY2hlLm9yZz6JAk4EEwEKADgWIQTF +5cCQMOe9Mt+aZ841q8lnYAStrgUCYYv8HAIbAwULCQgHAgYVCgkICwIEFgIDAQIe +AQIXgAAKCRA1q8lnYAStrinMD/96v0V5JOtvT2+NzkxyoZPFw/1H/jtAoCAm2IUq +PhUGibAPztREBcbr40I8l8bLghvN3PyNFop/TY7uxwzTzJrST1eZxML6x75pw6QK +2dbY0fFV3SEucDd8mCtVk/5F5ZWd7pXfYq4HVIcSikL0RbKHEl7N8fCRQHBQ63OA +MugeAnTfGhppQHLJQtN9iKx4iHt5aH38MMlhlzfqEwjMEfCm0OnnEWjLbjQgCFTU +1llnQEWxT1kwsiHKNvuTSuLrSP5SHsE/VGixLWUw3YzvDFrP5pmnY7XRz4jAynrS +QIoKnb6WtKCuos1Ym9gZIqXlPKZWfL93FBqD+lmHBMoPIVlubAOGR5scRd7sWhDd +ECnRWQZmIQ4b6g8dmcFQ/vC+1G75hr3EZEZHX6F3tS4lLHZ9NxiKK49ctD6UIJVP +4FIAOY+lB1LDVRObm4KuQ9sLO60p7Bh3xqEzqDZRLwO+z3vo7nQl+F+SwWRcI2tN +BrDaM+MDIrBiwPH79Ehi7r4fFVqzmHDvqa0eBjUnVx9g6AnlR91/4QX9ZLt+rUlg +ufJC35fUSJpRLLWUIAto8veLv7rd5mwbeocnncAXlx3+rN9NDEiFqeNeijLfv9bb +VXa7f1+vfyTn+ZrxB6vGM76bzZosJUWBhrVcq6Pv0Llxowy11z8tMBgALoKnwPhQ +KyBcNbkCDQRhi/wcARAAqA2X+BDf2aUFaMdSOGfxTf3y/moLREw0xw4zfGzpeMjY +ln+GrziX/+3bdwiDw8fwbe/r6M7jRW+66ndzI8J3qz6mpZpYbSUYdSpThqn2M/Pn +cwFjzP9hn5436MoiO+EPz7dukmXq1+a7L7arQUdpQ+LReFg31M8uDiaKmOGBibGw +2NmyD9NRsWsWn4thn4lu4ir1tSfgkSlSJPQyGF22Y1h6I5serjAbLqrXFG8+ziKv +HBXofYvQEnHynPzByJUy1CxAKojyvR+ARiSfhW2EOlB5USLjjGvgIKBko912EYU1 +s2GblBPkdBgHpMaVq4+uUdQcAvOpsscsoMMB3GQdhnMHrZGMjN+fPbMer4w721yo +495IOFGE97XSiO/1CPpVzIOPzl+QpSuRdl/GlKr30+vEUwTSXUYEbYSRMCofiRvv +63g6+dC0aN/8yVmnXCbehPu2EOmD5kl4VwrIADy7D1vIpXqetfIXPToovJo1wc/m +ZNXDXDnEImP2vQMuIb8pF/G66yfIiFTkvlORp3uA+G3wujnqq7eouseBx3vC7gap +fsSLqnMTCtZgh+qrogbeQzTNSVBQ4K6i1Ipbq+ti/ebRSMBf4WXeByD5Sk2+K1vo +5njW/8yXgg4zxpHdZo+s2RtpIzYQjjQFRFstR6RbBdcl7H348arvQiucyeYmK80A +EQEAAYkCNgQYAQoAIBYhBMXlwJAw570y35pnzjWryWdgBK2uBQJhi/wcAhsMAAoJ +EDWryWdgBK2uNdAQAI1FtRJ4mI6EOLjk9L9b/P3l5X0VY68c6eMMRc53goRr6cMj +1DlEGMSrFZ/uxadpVhdr7XZSUJy2CP8XwL7MOzkzGdshki1CgqECkkm4PPjBYUlJ +/aNPcQuaz7C6DF4X190Q81dCWG3nFzN1jJ8th+IRzTT6y1xJzMoslqeXqNf5sHyT +3tPkgLNcoFvUBmLglGlWOiuiWSkI+FFi+azGzgplPPWQiFEf47N5iEyhOLJYFkF+ +fR0u056EdTLV2pMqKU+9OEbB0gO8c2+hNXj3O/g+d2GsszrxHzLWwiX2haLfAcD8 +Eu8HBTp6nIa+q7kEAhhEoT3KPGTvIKFEtKzQmW9qa9XtEXjLHnmrMURGw1epVsE4 +/c1u5BughEZi3yw+yupnkRa7uR/IJw6Iw27OHYg9fyqMkGvfT0se9JTgud9GYggA +iaibIEq6K1sKjTE6Mk6KyGuQR6OrI7DB9HueFG6GP4UpZHXMdgqvlYXtn7iKTz8s +H/r/Ge2qzbQOFfJgZ/pI/7LL65XlgAbsSo79neztm6ExN5u9QBkpjsKYk2Gj8eny +vDrH4rzP6lkvLqCpCnOI+NHvmTpHI6XCi7XmQzBnBI7YDlNAuyx0axhkMCZs0bFx +lFYYlF0zWyTPNpVGZj3hMWq1mpsBOY4SWtN2T5gLoCGEPrgrZn1Gc4xFHnve +=jNCy +-----END PGP PUBLIC KEY BLOCK----- +pub rsa4096 2021-11-10 [SC] + 43C2306EBED3D533F2CFBA8A2C75E5A496C80880 +uid [ultimate] Wuwei Lin +sig 3 2C75E5A496C80880 2021-11-10 Wuwei Lin +sub rsa4096 2021-11-10 [E] +sig 2C75E5A496C80880 2021-11-10 Wuwei Lin + +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBGGMC8kBEAC0Gto+SCmuyo1CxQQkUwjkSgaU8eIkTNTHi3sCCZLC8nz1W7XY +Ksf/QgZQxdA1mrI41M5OWgg5fh0iySP6aU/C/WPCm6ebe4/VAo2BTJy4a2TgM3pg +LEd63mWj8XUAXLiQY2yju3PhkY3DCfVwghc81qqm21Ny2uGYe3w56N3OSOGUnrp2 +1zcALDvYhE00fszbPlDtpZ+YRB1Xf/NBEd7TW3fLzrDP5MqdhUt0TtdpfW+at6Op +6xW3uXWWbfkXYFMaM6xdatzqovMwPMMyUi1mnPRLL1i3toezm7GSAdgWHxlZBNEU +lyyg122NuDhx5Rbri4qHTUdiM5ZwxhuehJGRV1QmpRG2n6XVjgtEICL0NjoFTcVv +g/9EqSLKKmXo2qF/ubMepngX9nDcSlLLH5zQe2hvup3FSGep+SLJ8+sMjf01H/Vo +KiTPAN/C4LRsWiJTfHrUANminORo6+FMrqougk3QVd1X7X0MQFtxV5JPj5A0YdE2 +YkJAVMCOi/YGtBpktodZBpChMojjjlb0QyWo0YCLUcBIoA/Y6vCm1Z5CRYzUdvWt +SS94viKXpNJQaSospoxk1uWeEXM6Sj4/J52HGv86gP8EZwcbSjE6yRydokgB9qyW +nuIWEKFiDl74heAL1wtpk+aS3UVjb2dInnpALmNzNMl1UYwJszViTraz0QARAQAB +tBxXdXdlaSBMaW4gPHd1d2VpQGFwYWNoZS5vcmc+iQJSBBMBCAA8FiEEQ8Iwbr7T +1TPyz7qKLHXlpJbICIAFAmGMC8kCGwMFCwkIBwIDIgIBBhUKCQgLAgQWAgMBAh4H +AheAAAoJECx15aSWyAiAKlIP/RhCvkX4evnIlgDTNVt7W/XMFUua638mAj3p752M +FnH7FU/OTySE5wc/P4LJI7kNBLC9doF6RSpjrE87lSBRhYyPU7LVlTX5j5xbt3HD +nVZWe1XAj3wORR9mYDJaUABCY21qLBY2WGDeI3qGAQ5vjw/13HoYZAKcsQ9T8FN6 +FM+T6endSJUkqKSNLw+PiUAqosqI3ZgbShleD9jdwHzNqldwGWV47wJCS1UoOfnu +2I63EluPhOO+F44KXs0mAoEQqeqpuA4oXeyGhkbePR4xGIqqCDev1Gpr3KXDE0SH +4blXbKIEqWUYU3lU3/uUs8noaARkaNYkvfyNxKXXnyVqKFAPEZkGU+Nwp4VrtVp1 +wlqmnebzxVDWpxrkQrtr2sNDSYbJPC5fQqx4DyWctPNGWDRKmac25dW5JSbkFkVY +nBqFu464LNMtNS3RUL6cegFcV6Put+wYdqzV27BOaU0nnPGOrf2zDsVZdg6msfNB +eN/gABzRgW1iCuCItkwv2uDlabW/S6EV3Rkz9EVXNNoiPC6OwjVZAPvbB5tzmA9y +gCAsUWYjWH0VR5HuNmUIu76pDuGQVz7dk3xq6P+KF7LhX07oq5wAcBjxD66tKFIj +dIMfnJqu3Uy4UkF7cExg+IlZYsYyC2nBb0o8qDI+eVCEN5iLR+fr3OFKswhqFdMt +VWGcuQINBGGMC8kBEAChBfP599l60dioP51mR4s10mifMY/Ot+E8z8oAvvq0bQky +6Y+BcOWghHQ9dKsJ+UIQJhQHGKMVqgoVIy4rC+nVXcN5tLec4b8pKESJuLdcQ7P9 +1j03v31XvbpNmAUuUKl0xEkrHsRUlL9yfC6M8/PnZm9FImJmQWCageyl+T/zlDzy +LnZwQ7ko7mCF3haRBqCTuYpT6ICuZ0Pg/itVuje8WNkFH+kPH4Z6JlTboNoVf/UP +xcQYrnCwRtoQPdJb0jz2pTjKqtBirrKewVPE4meoZnUK6Q+h+yx36jTM9IqvP59F +/sW3kkQuHVKZj22qSyILBHxFJ1qjndjkIe5IX6w4bqXIEZWgXBJJggYqeBqytWhb +mx836Hf6oR6wlytG8M0NgkMMziPzK6hpns9swIdcPngHLn6XyNT7WxLZwMmh2xEd +P53qzyo8HAl3uIUQzz9QabOvUyEiw4PNaxyuqpPvyhXcmlRjfSs6NRceYyhXdUTA +lcKKMsZwNZ/i/rYYME5eVtEpRKmc6ZnbDRk+2la2RdJikRVzP4LAUut+yi5n/cal +qKW4685BC/aDmCWmQLAGZtSxWNBeTMnp5NpvVG/5LLSBHuraJePiOORXpFCdIira +BWsrHj1AfP831Byj53MMHS8C5Xr5J5JiQWKhxd5ASWPu4DjT3kAkRVZrvvpPfQAR +AQABiQI2BBgBCAAgFiEEQ8Iwbr7T1TPyz7qKLHXlpJbICIAFAmGMC8kCGwwACgkQ +LHXlpJbICIB7EhAAg2uspz5Vsw6QK76ipdkSAgUHeZU0MU3/af6qqrkseB1hAnck +E6fb1hUeRy4o5550eREgMi0uJDTqAoXvZ01oIKrfdZOsr1xLPHRrziBDvmSZVQmt +tIoMuEDhD8Pf7PNVemAIKQLqoleHeXKSlc1FP6DKcAIJK7jvIkb1alO9r9gXTQrM +8rHY4KSRh545HtZva6gBZjk+RfpQu6Sg/dMlwlDxTpoH0QjNalwzHD09sK9DrpOf +OhdTb3dYMBAMPyPWudUW0JbHhlJMqykCWdMSN5FxQIDcz4N4sH3idclOqBWzQq8Z +igf4cdBGaegHPGxOEMRdAKDOkVxP2ZwxJBLUFBThD/CfGRGhnwNTYoNpaPPekRPW +7Yg2JCnqI2pVGQBETX57J3wcQDb/TXQ7VP+ZttHMkGkU7IGoBdlzu9hhPapUs032 +Fy5AYoRozj9SuLKGbqy7VkvtEVZ7TeKaZO34fEJ3uRkDTHx0TQtqwvs1b3U1fWJj +o7469h/jBIPHJojx088Om0pMv91xJ7nQ3xukgVw9C0DZfmBX3xd2bNbyfugT8rqQ +A1PPxm4/KsXX/IZZOuM/tlT0vAahQsvXMNUVMg7v/PWuB6V47UdenKpXd10oloF7 +MMtVW5sxG8OoBpUIhJUCtYTlwGCyGWSR7+rsHSR2HydLk1RWcYNI3XgJ0ng= +=+gLd +-----END PGP PUBLIC KEY BLOCK----- diff --git a/LICENSE b/LICENSE index 52b2219396d2..18718f986baa 100644 --- a/LICENSE +++ b/LICENSE @@ -238,3 +238,8 @@ The Unlicense ------------- 3rdparty/rang + +BSD 3-Clause "New" or "Revised" License +--------------------------------------- + +3rdparty/cutlass \ No newline at end of file diff --git a/README.md b/README.md index 09ceb7ab1d07..d96038d17804 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ TVM is licensed under the [Apache-2.0](LICENSE) license. Getting Started --------------- Check out the [TVM Documentation](https://tvm.apache.org/docs/) site for installation instructions, tutorials, examples, and more. -The [Getting Started with TVM](https://tvm.apache.org/docs/tutorials/get_started/introduction.html) tutorial is a great +The [Getting Started with TVM](https://tvm.apache.org/docs/tutorial/introduction.html) tutorial is a great place to start. Contribute to TVM diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index b89bedbc6d45..3adcb2dc8d42 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/hexagon_launcher/README.md b/apps/hexagon_launcher/README.md index 85e6897b74a3..b190dd81a7b2 100644 --- a/apps/hexagon_launcher/README.md +++ b/apps/hexagon_launcher/README.md @@ -40,29 +40,33 @@ tvm_runtime, as well as the Hexagon launcher shared library and its correspondin tvm_runtime. As described in the [Manual compilation](#Manual compilation) section each component requires Hexagon and android dependencies. When building the launcher along with TVM these configurations must be providing when invoking cmake. A minimal -example invocation for compiling TVM along with the Hexagon launcher is included below, +example invocation for compiling TVM along with the Hexagon launcher is included below: ``` -cmake -DCMAKE_MAKE_PROGRAM=make \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ +cmake -DCMAKE_C_COMPILER=/path/to/clang \ + -DCMAKE_CXX_COMPILER=/path/to/clang++ \ -DCMAKE_CXX_FLAGS='-stdlib=libc++' \ -DCMAKE_CXX_STANDARD=14 \ - -DUSE_LLVM=/path/to/hexagon/llvm/bin/llvm-config \ + -DUSE_LLVM=/path/to/llvm/bin/llvm-config \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ -DUSE_HEXAGON_LAUNCHER=ON \ - -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ - -DANDROID_PLATFORM=android-28 \ - -DANDROID_ABI=arm64-v8a \ - -DUSE_HEXAGON_ARCH=v68 \ -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ - -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/Toolchain/ .. + -DUSE_HEXAGON_TOOLCHAIN=/path/to/hexagon/toolchain/ .. + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_ANDROID_TOOLCHAIN=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + .. ``` +where `v65|v66|v68` means "one of" these architecture versions. The Hexagon launcher application is an android binary and thus requires the use of an android toolchain for compilation. Similarly, the Hexagon tvm runtime requires the use of the Hexagon toolchain and depends on the Hexagon SDK. The -resulting hexagon launcher binaries can be found in the `launcher` subdirectory -of the cmake build directory. +resulting hexagon launcher binaries can be found in the `apps_hexagon_launcher` +subdirectory of the cmake build directory. Please note that the above command +will not build support for Hexagon codegen in the TVM library, for that please +additionally define the `USE_HEXAGON_DEVICE` variable. Also, the LLVM used in +`USE_LLVM` should have Hexagon target built in. ### Manual compilation @@ -72,43 +76,44 @@ code first. #### Compilation of the Hexagon part -1. Build the static version of TVM runtime for Hexagon. Use Hexagon clang - from the Hexagon SDK. This step is the same as building the shared version, - except at the cmake step, add `-DBUILD_STATIC_RUNTIME=ON`. The compilation - step should create `libtvm_runtime.a`. - -2. Create a subdirectory for the build files, and run `cmake` with the - following variables set: - - `FASTRPC_LIBS=SKEL` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=hexagon-clang` - - `CMAKE_CXX_COMPILER=hexagon-clang++` - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 - - `TVM_RUNTIME_HEXAGON=/path/to/libtvm_runtime.a` _statically_ linked - TVM runtime +Create a subdirectory for the build files, and run `cmake` with the +following variables set: - Make sure to provide the path to launcher's `CMakeLists.txt` directory - in `cmake` invocation. +``` +cmake -DCMAKE_C_COMPILER=/path/to/hexagon-clang \ + -DCMAKE_CXX_COMPILER=/path/to/hexagon-clang++ \ + -DUSE_HEXAGON_ARCH=v65|v66|v68 \ + -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ + /path/to/apps/hexagon_launcher/cmake/hexagon +``` -3. Run `make`. This will create `liblauncher_rpc_skel.so`. +Run `make`. This will create `liblauncher_rpc_skel.so`. The static version of +the TVM runtime for Hexagon will be built as a part of the process. #### Compilation of the Android part -1. Build TVM runtime for Android, using clang for AArch64 from the Android - NDK. Unlike in the Hexagon case, this should be the dynamic library (which - is the default), i.e. `libtvm_runtime.so`. - 2. Create a subdirectory for the build files (different from the one used for Hexagon files), and run `cmake` with the following variables set: - - `FASTRPC_LIBS=STUB` - - `USE_HEXAGON_SDK` to the path to the Hexagon SDK - - `CMAKE_C_COMPILER=aarch64-linux-android28-clang` (or later) - - `CMAKE_CXX_COMPILER=aarch64-linux-android28-clang++` (or later) - - `USE_HEXAGON_ARCH` to one of v65, v66, v68 (same as for the Hexagon part) - - `TVM_RUNTIME_ANDROID=/path/to/libtvm_runtime.so` dynamically or - statically linked TVM runtime - -3. Run `make`. This will create `launcher_android`. + +``` +cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DUSE_HEXAGON_SDK=/p/Hexagon_SDK/4.3.0.0 + -DUSE_HEXAGON_ARCH=v65|v66|v68 + /path/to/apps/hexagon_launcher/cmake/android +``` + +Run `make`. This will create `launcher_android`. The TVM runtime for Android will +be built as a part of the process. Depending on the version of cmake that you are +using, you may see the following warnings---they can be ignored. + +``` +An old version of CMake is being used that cannot automatically detect +compiler attributes. Compiler identification is being bypassed. Some +values may be wrong or missing. Update to CMake 3.19 or newer to use +CMake's built-in compiler identification. +``` ## Execution diff --git a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake index 4a7f803ce1ab..abf877cb67f1 100644 --- a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake +++ b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake @@ -15,11 +15,6 @@ # specific language governing permissions and limitations # under the License. -if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND - NOT "${FASTRPC_LIBS}" STREQUAL "STUB") - message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") -endif() - if(NOT DEFINED USE_HEXAGON_SDK) message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") endif() diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index c000b0e97cad..7716cde99863 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -21,17 +21,15 @@ project(HexagonAndroidLauncher C CXX) include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") add_custom_command( - OUTPUT ${LAUNCHER_RPC_STUB_C} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - COMMAND ${CMAKE_COMMAND} -E rename "${LAUNCHER_RPC_H}" - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + OUTPUT ${LAUNCHER_RPC_STUB_C} ${LAUNCHER_RPC_H} + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" ) include_directories(SYSTEM "${HEXAGON_SDK_INCLUDES}" "${HEXAGON_RPCMEM_ROOT}/inc" + "${CMAKE_CURRENT_BINARY_DIR}" # Output of qaic will go here ) link_directories(${HEXAGON_REMOTE_ROOT}) @@ -46,8 +44,9 @@ set(STUB_SRCS ) add_executable(launcher_android - "${STUB_SRCS}" + "${LAUNCHER_RPC_H}" "${LAUNCHER_RPC_STUB_C}" + "${STUB_SRCS}" ) ExternalProject_Add(android_tvm_runtime @@ -66,12 +65,14 @@ ExternalProject_Add(android_tvm_runtime ) ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) ExternalProject_Add_Step(android_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/libtvm_runtime.so + ${CMAKE_CURRENT_BINARY_DIR} DEPENDEES install ) add_dependencies(launcher_android android_tvm_runtime) -add_library(tvm_runtime SHARED IMPORTED) -set_target_properties(tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") +add_library(a_tvm_runtime SHARED IMPORTED) +set_target_properties(a_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.so") -target_link_libraries(launcher_android cdsprpc log tvm_runtime) +target_link_libraries(launcher_android cdsprpc log a_tvm_runtime) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index c76fcccc5a1a..3f99459f3a49 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -22,12 +22,14 @@ include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") add_custom_command( OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_H} - COMMAND ${QAIC_EXE} ${QAIC_FLAGS} - "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" ) -include_directories(SYSTEM ${HEXAGON_QURT_INCLUDES}) +include_directories(SYSTEM + ${HEXAGON_QURT_INCLUDES} + ${CMAKE_CURRENT_BINARY_DIR} # Output of qaic will go here +) link_directories(${HEXAGON_QURT_LIBS}) @@ -48,8 +50,9 @@ set(SKEL_SRCS "${LAUNCHER_SRC}/launcher_core.cc" "${LAUNCHER_SRC}/launcher_hexagon.cc" ) + add_library(launcher_rpc_skel SHARED - "${LAUNCHER_SRC}/${LAUNCHER_RPC_H}" + "${LAUNCHER_RPC_H}" "${LAUNCHER_RPC_SKEL_C}" "${SKEL_SRCS}" ) @@ -71,14 +74,10 @@ ExternalProject_Add(static_hexagon_tvm_runtime BUILD_ALWAYS ON ) ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) -ExternalProject_Add_Step(static_hexagon_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${CMAKE_INSTALL_PREFIX} - DEPENDEES install -) add_dependencies(launcher_rpc_skel static_hexagon_tvm_runtime) -add_library(static_tvm_runtime STATIC IMPORTED) -set_target_properties(static_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") +add_library(h_tvm_runtime STATIC IMPORTED) +set_target_properties(h_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") -target_link_libraries(launcher_rpc_skel -Wl,--whole-archive static_tvm_runtime -Wl,--no-whole-archive) +target_link_libraries(launcher_rpc_skel -Wl,--whole-archive h_tvm_runtime -Wl,--no-whole-archive) diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc index 6a5704d3888a..0fe9f9f59e4a 100644 --- a/apps/hexagon_launcher/launcher_core.cc +++ b/apps/hexagon_launcher/launcher_core.cc @@ -148,12 +148,13 @@ const tvm::runtime::PackedFunc get_module_func(tvm::runtime::Module module, } void reset_device_api() { - const tvm::runtime::PackedFunc api = get_runtime_func("device_api.cpu"); + const tvm::runtime::PackedFunc api = get_runtime_func("device_api.hexagon.v2"); tvm::runtime::Registry::Register("device_api.hexagon", true).set_body(api); } tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::runtime::PackedFunc loader = get_runtime_func("runtime.module.loadfile_so"); + static const tvm::runtime::PackedFunc loader = + get_runtime_func("runtime.module.loadfile_hexagon"); tvm::runtime::TVMRetValue rv = loader(file_name); if (rv.type_code() == kTVMModuleHandle) { return rv.operator tvm::runtime::Module(); @@ -169,7 +170,10 @@ tvm::runtime::Module create_graph_executor(const std::string& graph_json, uint64_t device_type = device.device_type; uint64_t device_id = device.device_id; + std::string linked_params = "tvm.runtime.hexagon.lookup_linked_params"; + const tvm::runtime::PackedFunc lookup_linked_params = get_runtime_func(linked_params); // Use default param lookup function (linked into the module). - tvm::runtime::TVMRetValue rv = create_executor(graph_json, graph_module, device_type, device_id); + tvm::runtime::TVMRetValue rv = + create_executor(graph_json, graph_module, lookup_linked_params, device_type, device_id); return rv.operator tvm::runtime::Module(); } diff --git a/apps/hexagon_launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h index f2aa8f10d0a6..91384133ab7b 100644 --- a/apps/hexagon_launcher/launcher_core.h +++ b/apps/hexagon_launcher/launcher_core.h @@ -89,6 +89,8 @@ struct Model { static tvm::Device device() { return tvm::Device{static_cast(kDLHexagon), 0}; } + static tvm::Device external() { return tvm::Device{static_cast(kDLCPU), 0}; } + tvm::runtime::PackedFunc run; }; diff --git a/apps/hexagon_launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc index 0a5d1f55e0c2..6925e1da9bfa 100644 --- a/apps/hexagon_launcher/launcher_hexagon.cc +++ b/apps/hexagon_launcher/launcher_hexagon.cc @@ -26,6 +26,8 @@ extern "C" { #include } +#include + #include #include #include @@ -106,7 +108,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int inpu DLTensor tensor{ const_cast(input_value), - Model::device(), + Model::external(), meta->ndim, meta->dtype, const_cast(meta->shape), @@ -153,6 +155,16 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out tvm::runtime::PackedFunc get_output = get_module_func(TheModel->graph_executor, "get_output"); tvm::runtime::NDArray output = get_output(output_idx); + std::vector shape_vec{output->shape, output->shape + output->ndim}; + + auto* container = new tvm::runtime::NDArray::Container( + static_cast(output_value), shape_vec, output->dtype, Model::external()); + container->SetDeleter([](tvm::Object* container) { + delete static_cast(container); + }); + + tvm::runtime::NDArray host_output(GetObjectPtr(container)); + if (meta_size != 0) { auto* meta = reinterpret_cast(output_meta); if (meta_size < meta->meta_size(output->ndim)) { @@ -170,8 +182,7 @@ AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int out return error_too_small(__func__, "value_size", value_size, data_size); } - auto data = reinterpret_cast(output->data); - std::copy(data, data + data_size, output_value); + host_output.CopyFrom(output); } return AEE_SUCCESS; diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 3f8c647fa4f2..9b476bbd47ce 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -94,6 +94,7 @@ - (void)open { server_.port = self.proxyPort.text.intValue; server_.key = self.proxyKey.text; server_.custom_addr = [NSString stringWithUTF8String:args.custom_addr]; + server_.verbose = args.verbose; server_.delegate = self; [server_ start]; diff --git a/apps/microtvm/arduino/template_project/crt_config/crt_config.h b/apps/microtvm/arduino/template_project/crt_config/crt_config.h index cf73103aff8b..b3126cfac920 100644 --- a/apps/microtvm/arduino/template_project/crt_config/crt_config.h +++ b/apps/microtvm/arduino/template_project/crt_config/crt_config.h @@ -36,7 +36,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index e285ecc6e3b0..1768c61197a9 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -31,11 +31,14 @@ import tempfile import time from string import Template +import re import serial import serial.tools.list_ports from tvm.micro.project_api import server +_LOG = logging.getLogger(__name__) + MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar" API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd()) BUILD_DIR = API_SERVER_DIR / "build" @@ -43,6 +46,10 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() +# Used to check Arduino CLI version installed on the host. +# We only check two levels of the version. +ARDUINO_CLI_VERSION = 0.18 + BOARDS = API_SERVER_DIR / "boards.json" @@ -77,6 +84,11 @@ class BoardAutodetectFailed(Exception): server.ProjectOption( "verbose", help="True to pass --verbose flag to arduino-cli compile and upload" ), + server.ProjectOption( + "warning_as_error", + choices=(True, False), + help="Treat warnings as errors and raise an Exception.", + ), ] @@ -91,7 +103,7 @@ def server_info_query(self, tvm_version): return server.ServerInfo( platform_name="arduino", is_template=IS_TEMPLATE, - model_library_format_path=MODEL_LIBRARY_FORMAT_PATH, + model_library_format_path="" if IS_TEMPLATE else MODEL_LIBRARY_FORMAT_PATH, project_options=PROJECT_OPTIONS, ) @@ -275,7 +287,25 @@ def _find_modified_include_path(self, project_dir, file_path, include_path): # It's probably a standard C/C++ header return include_path + def _get_platform_version(self, arduino_cli_path: str) -> float: + # sample output of this command: + # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' + version_output = subprocess.check_output([arduino_cli_path, "version"], encoding="utf-8") + full_version = re.findall("version: ([\.0-9]*)", version_output.lower()) + full_version = full_version[0].split(".") + version = float(f"{full_version[0]}.{full_version[1]}") + + return version + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + # Check Arduino version + version = self._get_platform_version(options["arduino_cli_cmd"]) + if version != ARDUINO_CLI_VERSION: + message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}." + if options.get("warning_as_error") is not None and options["warning_as_error"]: + raise server.ServerError(message=message) + _LOG.warning(message) + # Reference key directories with pathlib project_dir = pathlib.Path(project_dir) project_dir.mkdir() @@ -300,7 +330,7 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Unpack the MLF and copy the relevant files metadata = self._disassemble_mlf(model_library_format_path, source_dir) - shutil.copy2(model_library_format_path, source_dir / "model") + shutil.copy2(model_library_format_path, project_dir / MODEL_LIBRARY_FORMAT_RELPATH) # For AOT, template model.h with metadata to minimize space usage if options["project_type"] == "example_project": diff --git a/apps/microtvm/ethosu/Makefile b/apps/microtvm/ethosu/Makefile index 65cf6524bc0c..370799972de6 100644 --- a/apps/microtvm/ethosu/Makefile +++ b/apps/microtvm/ethosu/Makefile @@ -28,7 +28,7 @@ CMSIS_PATH ?= ${ETHOSU_PATH}/cmsis ETHOSU_PLATFORM_PATH ?= ${ETHOSU_PATH}/core_platform CORSTONE_300_PATH = ${ETHOSU_PLATFORM_PATH}/targets/corstone-300 PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=cortex-m55 -mthumb -mfloat-abi=hard -std=gnu99 -CMAKE = cmake +CMAKE ?= cmake CC = arm-none-eabi-gcc AR = arm-none-eabi-ar RANLIB = arm-none-eabi-ranlib diff --git a/apps/microtvm/ethosu/run_demo.sh b/apps/microtvm/ethosu/run_demo.sh index de33bfe8d427..5d9efb359b24 100755 --- a/apps/microtvm/ethosu/run_demo.sh +++ b/apps/microtvm/ethosu/run_demo.sh @@ -32,6 +32,10 @@ Usage: run_demo.sh [--ethosu_driver_path ETHOSU_DRIVER_PATH] Set path to CMSIS. --ethosu_platform_path ETHOSU_PLATFORM_PATH Set path to Arm(R) Ethos(TM)-U core platform. +--fvp_path FVP_PATH + Set path to FVP. +--cmake_path + Set path to cmake. EOF } @@ -79,6 +83,30 @@ while (( $# )); do fi ;; + --fvp_path) + if [ $# -gt 1 ] + then + export PATH="$2/models/Linux64_GCC-6.4:$PATH" + shift 2 + else + echo 'ERROR: --fvp_path requires a non-empty argument' >&2 + show_usage >&2 + exit 1 + fi + ;; + + --cmake_path) + if [ $# -gt 1 ] + then + export CMAKE="$2" + shift 2 + else + echo 'ERROR: --cmake_path requires a non-empty argument' >&2 + show_usage >&2 + exit 1 + fi + ;; + -*|--*) echo "Error: Unknown flag: $1" >&2 show_usage >&2 @@ -100,8 +128,10 @@ mobilenet_url='https://storage.googleapis.com/download.tensorflow.org/models/mob curl --retry 64 -sSL ${mobilenet_url} | gunzip | tar -xvf - ./mobilenet_v1_1.0_224_quant.tflite # Compile model for Arm(R) Cortex(R)-M55 CPU and Ethos(TM)-U55 NPU -tvmc compile --target="ethos-u -accelerator_config=ethos-u55-256, \ - c -runtime=c --link-params -mcpu=cortex-m55 --executor=aot --interface-api=c --unpacked-api=1" \ +# An alternative to using "python3 -m tvm.driver.tvmc" is to call +# "tvmc" directly once TVM has been pip installed. +python3 -m tvm.driver.tvmc compile --target="ethos-u -accelerator_config=ethos-u55-256, \ + c -runtime=c --link-params -mcpu=cortex-m55 -executor=aot -interface-api=c -unpacked-api=1" \ --pass-config tir.disable_vectorize=1 ./mobilenet_v1_1.0_224_quant.tflite --output-format=mlf tar -xvf module.tar diff --git a/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh index 11d89f2cd44e..2724069ba722 100644 --- a/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh +++ b/apps/microtvm/reference-vm/arduino/base-box/base_box_provision.sh @@ -31,8 +31,12 @@ cd ~ sudo apt-get install -y ca-certificates # Install Arduino-CLI (specific version) +# To keep in sync with the version +# defined in apps/microtvm/arduino/template_project/microtvm_api_server.py +ARDUINO_CLI_VERSION="0.18.3" + export PATH="/home/vagrant/bin:$PATH" -wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s 0.18.3 +wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s ${ARDUINO_CLI_VERSION} # Arduino (the CLI and GUI) require the dialout permission for uploading sudo usermod -a -G dialout $USER diff --git a/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh index 0631e89f3bb3..0e83d1b8be97 100644 --- a/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh +++ b/apps/microtvm/reference-vm/zephyr/base-box/base_box_provision.sh @@ -28,7 +28,8 @@ source ~/.profile # Init Zephyr cd ~ # Using most recent commit that passes all the tests. -~/ubuntu_init_zephyr_project.sh ~/zephyr v2.5-branch --commit dabf23758417fd041fec2a2a821d8f526afac29d +ZEPHYR_VERSION="v2.5-branch" +~/ubuntu_init_zephyr_project.sh ~/zephyr ${ZEPHYR_VERSION} --commit dabf23758417fd041fec2a2a821d8f526afac29d # Cleanup rm -f *.sh diff --git a/apps/microtvm/zephyr/template_project/boards.json b/apps/microtvm/zephyr/template_project/boards.json index aabed3322150..18e393897f04 100644 --- a/apps/microtvm/zephyr/template_project/boards.json +++ b/apps/microtvm/zephyr/template_project/boards.json @@ -39,7 +39,7 @@ "board": "qemu_riscv32", "model": "host", "is_qemu": true, - "fpu": true + "fpu": false }, "qemu_riscv64": { "board": "qemu_riscv64", diff --git a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h index f8fc7514a28d..c3beaed522f2 100644 --- a/apps/microtvm/zephyr/template_project/crt_config/crt_config.h +++ b/apps/microtvm/zephyr/template_project/crt_config/crt_config.h @@ -36,7 +36,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 @@ -48,7 +48,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index f700b5774c72..7e13f928b288 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -61,6 +61,11 @@ BOARDS = API_SERVER_DIR / "boards.json" +# Used to check Zephyr version installed on the host. +# We only check two levels of the version. +ZEPHYR_VERSION = 2.5 + + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. try: @@ -265,6 +270,15 @@ def _get_nrf_device_args(options): "config_main_stack_size", help="Sets CONFIG_MAIN_STACK_SIZE for Zephyr board.", ), + server.ProjectOption( + "warning_as_error", + choices=(True, False), + help="Treat warnings as errors and raise an Exception.", + ), + server.ProjectOption( + "compile_definitions", + help="Extra definitions added project compile.", + ), ] @@ -342,7 +356,27 @@ def _create_prj_conf(self, project_dir, options): "aot_demo": "memory microtvm_rpc_common common", } + def _get_platform_version(self) -> float: + with open(pathlib.Path(os.getenv("ZEPHYR_BASE")) / "VERSION", "r") as f: + lines = f.readlines() + for line in lines: + line = line.replace(" ", "").replace("\n", "").replace("\r", "") + if "VERSION_MAJOR" in line: + version_major = line.split("=")[1] + if "VERSION_MINOR" in line: + version_minor = line.split("=")[1] + + return float(f"{version_major}.{version_minor}") + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): + # Check Zephyr version + version = self._get_platform_version() + if version != ZEPHYR_VERSION: + message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}." + if options.get("warning_as_error") is not None and options["warning_as_error"]: + raise server.ServerError(message=message) + _LOG.warning(message) + project_dir = pathlib.Path(project_dir) # Make project directory. project_dir.mkdir() @@ -389,6 +423,11 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec cmake_f.write(line) + if options.get("compile_definitions"): + flags = options.get("compile_definitions") + for item in flags: + cmake_f.write(f"target_compile_definitions(app PUBLIC {item})\n") + self._create_prj_conf(project_dir, options) # Populate crt-config.h diff --git a/apps/microtvm/zephyr/template_project/src/aot_demo/main.c b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c index a96e3b4d0a4e..3946727b26a8 100644 --- a/apps/microtvm/zephyr/template_project/src/aot_demo/main.c +++ b/apps/microtvm/zephyr/template_project/src/aot_demo/main.c @@ -38,14 +38,21 @@ #include "posix_board_if.h" #endif -#define WORKSPACE_SIZE (270 * 1024) +// WORKSPACE_SIZE defined in Project API Makefile static uint8_t g_aot_memory[WORKSPACE_SIZE]; tvm_workspace_t app_workspace; -// Wakeup sequence used to wake up QEMU on the host. -const unsigned char g_wakeup_sequence[] = "#wakeup\n"; -const char g_start_cmd[] = "start\n"; +// Transport Commands. +// Commands on host end with `\n` +// Commands on microTVM device end with `%` +const unsigned char CMD_WAKEUP[] = "wakeup\n"; +const unsigned char CMD_READY[] = "ready\n"; +const unsigned char CMD_INIT[] = "init"; +const unsigned char CMD_INFER[] = "infer"; + +#define CMD_SIZE 80u +#define CMD_TERMINATOR '%' size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, va_list args) { @@ -163,35 +170,10 @@ int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { } static uint8_t main_rx_buf[128]; -static uint8_t cmd_buf[128]; +static uint8_t g_cmd_buf[128]; static size_t g_cmd_buf_ind; -void main(void) { - g_cmd_buf_ind = 0; - memset((char*)cmd_buf, 0, sizeof(cmd_buf)); - TVMPlatformUARTInit(); - k_timer_init(&g_microtvm_timer, NULL, NULL); - // Wake up host side. - TVMPlatformWriteSerial(g_wakeup_sequence, sizeof(g_wakeup_sequence)); - - // Wait for start command - while (true) { - int bytes_read = TVMPlatformUartRxRead(main_rx_buf, sizeof(main_rx_buf)); - if (bytes_read > 0) { - memcpy((char*)cmd_buf + g_cmd_buf_ind, main_rx_buf, bytes_read); - g_cmd_buf_ind += bytes_read; - } - if (g_cmd_buf_ind >= 6) { - if (!strcmp((char*)(cmd_buf), g_start_cmd)) { - break; - } else { - memset((char*)cmd_buf, 0, sizeof(cmd_buf)); - g_cmd_buf_ind = 0; - } - } - } - TVMLogf("Zephyr AOT Runtime\n"); - +void TVMInfer() { struct tvmgen_default_inputs inputs = { .input_1 = input_data, }; @@ -219,7 +201,47 @@ void main(void) { max_val = output_data[i]; } } - TVMLogf("#result:%d:%d\n", max_ind, (uint32_t)(elapsed_time * 1000)); + TVMLogf("result:%d:%d\n", max_ind, (uint32_t)(elapsed_time * 1000)); +} + +// Execute functions based on received command +void command_ready(char* command) { + if (strncmp(command, CMD_INIT, CMD_SIZE) == 0) { + TVMPlatformWriteSerial(CMD_WAKEUP, sizeof(CMD_WAKEUP)); + } else if (strncmp(command, CMD_INFER, CMD_SIZE) == 0) { + TVMInfer(); + } else { + TVMPlatformWriteSerial(CMD_READY, sizeof(CMD_READY)); + } +} + +// Append received characters to buffer and check for termination character. +void serial_callback(char* message, int len_bytes) { + for (int i = 0; i < len_bytes; i++) { + if (message[i] == CMD_TERMINATOR) { + g_cmd_buf[g_cmd_buf_ind] = (char)0; + command_ready(g_cmd_buf); + g_cmd_buf_ind = 0; + } else { + g_cmd_buf[g_cmd_buf_ind] = message[i]; + g_cmd_buf_ind += 1; + } + } +} + +void main(void) { + g_cmd_buf_ind = 0; + memset((char*)g_cmd_buf, 0, sizeof(g_cmd_buf)); + TVMPlatformUARTInit(); + k_timer_init(&g_microtvm_timer, NULL, NULL); + + while (true) { + int bytes_read = TVMPlatformUartRxRead(main_rx_buf, sizeof(main_rx_buf)); + if (bytes_read > 0) { + serial_callback(main_rx_buf, bytes_read); + } + } + #ifdef CONFIG_ARCH_POSIX posix_exit(0); #endif diff --git a/apps/microtvm/zephyr/template_project/src/host_driven/main.c b/apps/microtvm/zephyr/template_project/src/host_driven/main.c index 43064e804193..44d656028cbc 100644 --- a/apps/microtvm/zephyr/template_project/src/host_driven/main.c +++ b/apps/microtvm/zephyr/template_project/src/host_driven/main.c @@ -260,11 +260,6 @@ void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) { // The main function of this application. extern void __stdout_hook_install(int (*hook)(int)); void main(void) { - // TODO (mehrdadh): Update this when zephyr version has updated to 2.6. - // Update zephyr to latest version to use with qemu_riscv32. -#ifdef CONFIG_BOARD_QEMU_RISCV32 - k_float_enable(_current, 0); -#endif #ifdef CONFIG_LED int ret; diff --git a/apps/pt_tvmdsoop/CMakeLists.txt b/apps/pt_tvmdsoop/CMakeLists.txt new file mode 100644 index 000000000000..05b3b0babc01 --- /dev/null +++ b/apps/pt_tvmdsoop/CMakeLists.txt @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +cmake_minimum_required(VERSION 3.2) +project(pt_tvmdsoop C CXX) + +set(BUILD_PT_TVMDSOOP_ONLY ON) +set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT}) +set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build) + +include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dlpack/include/) +include_directories(SYSTEM ${TVM_ROOT}/3rdparty/dmlc-core/include/) +include_directories(${TVM_ROOT}/include) + +link_directories(${TVM_ROOT}/build) + +include(${TVM_ROOT}/cmake/utils/Utils.cmake) +include(${TVM_ROOT}/cmake/utils/FindCUDA.cmake) +include(${TVM_ROOT}/cmake/modules/CUDA.cmake) + +include(${TVM_ROOT}/cmake/modules/contrib/PT_TVMDSOOP.cmake) diff --git a/apps/pt_tvmdsoop/prepare_and_test_pt_tvm_class.sh b/apps/pt_tvmdsoop/prepare_and_test_pt_tvm_class.sh new file mode 100755 index 000000000000..666f774017c8 --- /dev/null +++ b/apps/pt_tvmdsoop/prepare_and_test_pt_tvm_class.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(cd $(dirname $0)/../..; pwd) +echo "TVM_ROOT=${TVM_ROOT}" + +export PYTHONPATH=${TVM_ROOT}/python + +if [ ! -f $TVM_ROOT/build/libtvm.so ]; then + echo "$TVM_ROOT/build/libtvm.so missing" + exit 1 +fi + +if [ ! -f $TVM_ROOT/build/libtvm_runtime.so ]; then + echo "$TVM_ROOT/build/libtvm_runtime.so missing" + exit 1 +fi + +python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1 + +if [ "$?" -eq 0 ]; then + echo "Build PT_TVMDSOOP with gpu support and execute tests" + CMAKE_OPTIONS="-DUSE_CUDA=ON -DUSE_CUDNN=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}" + mkdir -p build + cd build; cmake .. ${CMAKE_OPTIONS} && make + cp *.so $TVM_ROOT/build/ + cd .. + + LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests +fi + diff --git a/apps/pt_tvmdsoop/tests/test_torch_compile_cpu.py b/apps/pt_tvmdsoop/tests/test_torch_compile_cpu.py new file mode 100644 index 000000000000..5ad88b45dc80 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_torch_compile_cpu.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch module""" +import torch +import time +import tvm +from tvm.contrib.torch import compile + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return x * x + + +model = Model() +x = torch.rand([1, 3, 224, 224]) +model_jit = torch.jit.trace(model, x) +print(model_jit.graph) + +print("run torchscript...") +for i in range(20): + t = time.time() + model_jit(x) + print(time.time() - t) + + +option = { + "input_infos": [ + ("x", (1, 3, 224, 224)), + ], + "default_dtype": "float16", + "export_dir": "pytorch_compiled", + "num_outputs": 1, + "tuning_n_trials": 1, # set zero to skip tuning + "tuning_log_file": "tuning.log", + "target": "llvm", + "device": tvm.cpu(), +} + +pytorch_tvm_module = compile(model_jit, option) +torch.jit.script(pytorch_tvm_module).save("model_tvm.pt") + + +print("Run PyTorch...") +for i in range(20): + t = time.time() + outputs = pytorch_tvm_module.forward([x.cpu()]) + print(1000 * (time.time() - t)) +print(outputs[0].shape) diff --git a/apps/pt_tvmdsoop/tests/test_torch_compile_gpu.py b/apps/pt_tvmdsoop/tests/test_torch_compile_gpu.py new file mode 100644 index 000000000000..b2ceb7f5cd6b --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_torch_compile_gpu.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch module""" +import torch +import time +from torchvision.models import resnet50 +import tvm +from tvm.contrib.torch import compile + + +model = resnet50().half().cuda() +x = torch.rand([1, 3, 224, 224]).half().cuda() +model_jit = torch.jit.trace(model, x) +print(model_jit.graph) + +print("run torchscript...") +for i in range(20): + t = time.time() + model_jit(x) + torch.cuda.synchronize() + print(time.time() - t) + + +option = { + "input_infos": [ + ("x", (1, 3, 224, 224)), + ], + "default_dtype": "float16", + "export_dir": "pytorch_compiled", + "num_outputs": 1, + "tuning_n_trials": 1, # set zero to skip tuning + "tuning_log_file": "tuning.log", + "target": "cuda", + "device": tvm.cuda(0), +} + +pytorch_tvm_module = compile(model_jit, option) +torch.jit.script(pytorch_tvm_module).save("model_tvm.pt") + + +print("Run PyTorch...") +for i in range(20): + t = time.time() + outputs = pytorch_tvm_module.forward([x]) + torch.cuda.synchronize() + print(1000 * (time.time() - t)) +print(outputs[0].shape) diff --git a/apps/pt_tvmdsoop/tests/test_torch_graph_module.py b/apps/pt_tvmdsoop/tests/test_torch_graph_module.py new file mode 100644 index 000000000000..4e3b51227cbe --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_torch_graph_module.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch module""" +import tempfile +import os +import logging +import torch +import numpy as np +import tvm +import tvm.testing +from tvm import te, relay +import tvm.contrib.torch +from tvm.contrib import graph_runtime + +TVM_ASSETS = ["mod.so", "graph.json", "params"] + + +def test_use_pt_graph_module(): + """main test function""" + + def build_export_graph(device): + """relay build & export graph""" + x = relay.var("x", shape=(10, 5)) + y = relay.var("y", shape=(1, 5)) + z = relay.add(x, y) + z = relay.exp(z) + func = relay.Function([x, y], z) + x_data = np.random.rand(10, 5).astype("float32") + y_data = np.random.rand(1, 5).astype("float32") + params = {"y": y_data} + + pt_device = torch.device(device) + if pt_device.type == "cuda": + target = "cuda" + ctx = tvm.cuda(pt_device.index) + else: + target = "llvm" + ctx = tvm.cpu(0) + + graph, lib, params = relay.build(tvm.IRModule.from_expr(func), target=target, params=params) + mod = graph_runtime.create(graph, lib, device=ctx) + mod.set_input(**params) + mod.set_input(x=x_data) + mod.run() + res = mod.get_output(0).asnumpy() + ref_res = np.exp(y_data + x_data) + tvm.testing.assert_allclose(res, ref_res, atol=1e-5, rtol=1e-5) + + # export to tempdir + export_dir = tempfile.mkdtemp("tvm_export") + lib.export_library(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "w") as fout: + fout.write(graph) + with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout: + fout.write(relay.save_param_dict(params)) + + return export_dir + + def test_pt_run(device, trace=True, to_device=None): + """test add lib with Pytorch wrapper""" + print("\n############## Test on device:", device, "#################") + export_dir = build_export_graph(device) + engine = tvm.contrib.torch.GraphModule(num_inputs=2, num_outputs=1).to(device) + + x = np.random.rand(10, 5).astype("float32") + y = np.random.rand(1, 5).astype("float32") + + expect = np.exp(y + x) + + def get_inputs_by_device(device): + inps = [torch.Tensor(x), torch.Tensor(y)] + if device == "cpu": + return inps + else: + device_type, device_id = device.split(":") + assert device_type == "cuda" + return [inp.cuda(int(device_id)) for inp in inps] + + assets = [os.path.join(export_dir, i) for i in TVM_ASSETS] + engine.init((x.shape, y.shape), *assets) + + outputs = engine.forward(get_inputs_by_device(device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + + if trace: + print("\n################ Test trace and load #################") + scripted = torch.jit.script(engine) + scripted_dir = tempfile.mkdtemp("scripted") + scripted_path = os.path.join(scripted_dir, "model.pt") + scripted.save(scripted_path) + loaded = torch.jit.load(scripted_path) + outputs = loaded.forward(get_inputs_by_device(device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + del scripted + del loaded + + if to_device: + print( + "\n################ Test move from [{}] to [{}] #################".format( + device, to_device + ) + ) + engine = engine.to(to_device) + outputs = engine.forward(get_inputs_by_device(to_device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + del engine + + test_pt_run(device="cuda:0", trace=True, to_device="cuda:1") + test_pt_run(device="cpu", trace=True) + + +if __name__ == "__main__": + test_use_pt_graph_module() diff --git a/apps/pt_tvmdsoop/tests/test_torch_script.py b/apps/pt_tvmdsoop/tests/test_torch_script.py new file mode 100644 index 000000000000..34b959714a18 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_torch_script.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch module""" +import os +import torch +import time +import numpy as np +import tvm +import tvm.testing +import tempfile +from tvm.contrib.torch import PyTorchTVMModule, compile + + +class Model(torch.nn.Module): + def forward(self, x, y): + return torch.matmul(x, y.softmax(1)) + + +model = Model() +model.cuda().half() +x = torch.rand([1280, 2464, 4]).cuda().half() +y = torch.rand([1280, 4, 1]).cuda().half() +for i in range(20): + t = time.time() + o = model(x, y) + torch.cuda.synchronize() + print(1000 * (time.time() - t)) +print(o.shape) + + +model_jit = torch.jit.script(model) +print(model_jit.graph) +input_shapes = [("x", list(x.shape)), ("y", list(y.shape))] +dtype = "float16" +export_dir = tempfile.mkdtemp("pytorch_compiled") +print("tmp export_dir:", export_dir) + + +mod = PyTorchTVMModule() +print("Converting...") +mod.from_pytorch(model_jit, input_shapes, dtype) + +log_file = os.path.join(export_dir, "tuning.log") +if not os.path.exists(log_file): + print("Tuning...") + mod.tune_tvm(log_file=log_file, n_trial=20) + +print("Building...") +tvm_mod = mod.build_tvm(export_dir) +pytorch_mod = mod.build_pytorch_module(num_inputs=2, num_outputs=1) + + +## Or you can load from a prebuilt tvm module +# mod = PyTorchTVMModule() +# tvm_mod = mod.load_tvm(export_dir) +# pytorch_mod = mod.build_pytorch_module(num_inputs=2, num_outputs=1, input_infos=input_shapes) + + +print("Run TVM...") +tvm_x = tvm.nd.array(x.cpu().numpy().astype(dtype), device=tvm.gpu(0)) +tvm_y = tvm.nd.array(y.cpu().numpy().astype(dtype), device=tvm.gpu(0)) +for i in range(20): + t = time.time() + tvm_mod.run(x=tvm_x, y=tvm_y) + print(1000 * (time.time() - t)) +tvm_output = tvm_mod.get_output(0) +print(tvm_output.shape) + + +print("Run PyTorch...") +for i in range(20): + t = time.time() + outputs = pytorch_mod.forward([x, y]) + torch.cuda.synchronize() + print(1000 * (time.time() - t)) +print(outputs[0].shape) + + +class EnsembleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.jit.script(pytorch_mod) + + def forward(self, x, y, z) -> torch.Tensor: + if x > 1: + out = self.layer(y, z)[0] + else: + out = torch.ones([1280, 2464, 1]) + return out + + +print("Exporting...") +scripted = torch.jit.script(EnsembleModel()) +print(scripted.graph) +scripted_path = os.path.join(export_dir, "model_tvm.pt") +scripted.save(scripted_path) + + +# print(o == outputs[0]) +# print(o - outputs[0]) diff --git a/apps/pt_tvmdsoop/tests/test_torch_vm_module.py b/apps/pt_tvmdsoop/tests/test_torch_vm_module.py new file mode 100644 index 000000000000..81d9dadb02c1 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_torch_vm_module.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch vm module""" +import tempfile +import os +import logging +import torch +import numpy as np +import tvm +from tvm.contrib.torch.pytorch_tvm import TVM_ASSETS +import tvm.testing +from tvm import te, relay +import tvm.contrib.torch +from tvm.contrib import graph_runtime + +TVM_ASSETS = ["mod.so", "code.ro"] + + +def test_use_pt_vm_module(): + """main test function""" + + def build_export_vm(device): + """relay build & export graph""" + x = relay.var("x", shape=(10, 5)) + y = relay.var("y", shape=(1, 5)) + z = relay.add(x, y) + z = relay.exp(z) + func = relay.Function([x, y], z) + x_data = np.random.rand(10, 5).astype("float32") + y_data = np.random.rand(1, 5).astype("float32") + + pt_device = torch.device(device) + if pt_device.type == "cuda": + target = "cuda" + ctx = tvm.cuda(pt_device.index) + else: + target = "llvm" + ctx = tvm.cpu(0) + exe = relay.vm.compile(tvm.IRModule.from_expr(func), target=target, params={}) + code, lib = exe.save() + export_dir = tempfile.mkdtemp("tvm_export") + # export to tempdir + lib.export_library(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "wb") as fout: + fout.write(code) + vm = tvm.runtime.vm.VirtualMachine(exe, ctx) + res = vm.run(x_data, y_data) + ref_res = np.exp(y_data + x_data) + tvm.testing.assert_allclose(res.numpy(), ref_res, atol=1e-5, rtol=1e-5) + return export_dir + + def test_pt_run(device, trace=True, to_device=None, inp_on_cuda=False): + """test add lib with Pytorch wrapper""" + print("\n############## Test on device:", device, "#################") + export_dir = build_export_vm(device) + engine = tvm.contrib.torch.VMModule(num_inputs=2, num_outputs=1).to(device) + + x = np.random.rand(10, 5).astype("float32") + y = np.random.rand(1, 5).astype("float32") + + expect = np.exp(y + x) + + def get_inputs_by_device(device): + inps = [torch.Tensor(x), torch.Tensor(y)] + if device == "cpu": + return inps + else: + device_type, device_id = device.split(":") + assert device_type == "cuda" + return [inp.cuda(int(device_id)) for inp in inps] + + assets = [os.path.join(export_dir, i) for i in TVM_ASSETS] + engine.init((x.shape, y.shape), *assets) + + outputs = engine.forward(get_inputs_by_device(device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + + if trace: + print("\n################ Test trace and load #################") + scripted = torch.jit.script(engine) + scripted_dir = tempfile.mkdtemp("scripted") + scripted_path = os.path.join(scripted_dir, "model.pt") + scripted.save(scripted_path) + loaded = torch.jit.load(scripted_path) + outputs = loaded.forward(get_inputs_by_device(device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + del scripted + del loaded + + if to_device: + print( + "\n################ Test move from [{}] to [{}] #################".format( + device, to_device + ) + ) + engine = engine.to(to_device) + outputs = engine.forward(get_inputs_by_device(to_device)) + tvm.testing.assert_allclose(outputs[0].cpu(), expect, atol=1e-5, rtol=1e-5) + del engine + + test_pt_run(device="cuda:0", trace=True, to_device="cuda:1", inp_on_cuda=True) + test_pt_run(device="cpu", trace=True, inp_on_cuda=False) + + +if __name__ == "__main__": + test_use_pt_vm_module() diff --git a/apps/pt_tvmdsoop/tests/test_trace_tvm_module.py b/apps/pt_tvmdsoop/tests/test_trace_tvm_module.py new file mode 100644 index 000000000000..0a12ec529fa0 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_trace_tvm_module.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for torch module""" +import torch +import time +import tvm +from tvm.contrib.torch import compile, TraceTvmModule, pytorch_tvm + + +class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x * y + + +model = Model() +x = torch.rand([1, 2, 3]) +y = torch.rand([1, 2, 3]) +model_jit = torch.jit.script(model) + +option = { + "input_infos": [("x", (1, 2, 3)), ("y", (1, 2, 3))], + "default_dtype": "float32", + "export_dir": "pytorch_compiled", + "num_outputs": 1, + "tuning_n_trials": 0, # set zero to skip tuning + "tuning_log_file": "tuning.log", + "target": "llvm", + "device": tvm.cpu(), +} + +# use TraceTvmModule to convert List[Tensor] input/output +# to tuple of Tensors +pytorch_tvm_module = compile(model_jit, option) +scripted = torch.jit.script(pytorch_tvm_module) +traced = torch.jit.trace(TraceTvmModule(scripted), (x, y)) + +res_traced = traced.forward(x, y) +res_expected = pytorch_tvm_module.forward([x, y])[0] +tvm.testing.assert_allclose(res_traced, res_expected) diff --git a/cmake/config.cmake b/cmake/config.cmake index 1fce11f90aed..30c0eab60ea5 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -272,6 +272,9 @@ set(USE_THRUST OFF) # Whether to build the TensorFlow TVMDSOOp module set(USE_TF_TVMDSOOP OFF) +# Whether to build the PyTorch custom class module +set(USE_PT_TVMDSOOP OFF) + # Whether to use STL's std::unordered_map or TVM's POD compatible Map set(USE_FALLBACK_STL_MAP OFF) @@ -347,3 +350,7 @@ set(USE_PAPI OFF) # Note that cmake will use `find_package` to find GTest. Please use cmake's # predefined variables to specify the path to the GTest package if needed. set(USE_GTEST AUTO) + +# Enable using CUTLASS as a BYOC backend +# Need to have USE_CUDA=ON +set(USE_CUTLASS OFF) diff --git a/cmake/modules/Arduino.cmake b/cmake/modules/Arduino.cmake new file mode 100644 index 000000000000..54c144081efa --- /dev/null +++ b/cmake/modules/Arduino.cmake @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. The ASF licenses this +# file to you under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +if(USE_MICRO) + message(STATUS "Add Arduino for microTVM") + + function(microtvm_add_arduino) + list( + APPEND + ARDUINO_FILE_COPY_JOBS + "apps/microtvm/arduino/template_project microtvm_api_server.py -> arduino" + "apps/microtvm/arduino/template_project boards.json -> arduino" + "apps/microtvm/arduino/template_project/src/example_project *.c -> arduino/src/example_project" + "apps/microtvm/arduino/template_project/src/example_project *.h -> arduino/src/example_project" + "apps/microtvm/arduino/template_project/src/example_project *.ino -> arduino/src/example_project" + "apps/microtvm/arduino/template_project/src/host_driven *.c -> arduino/src/host_driven" + "apps/microtvm/arduino/template_project/src/host_driven *.ino -> arduino/src/host_driven" + "apps/microtvm/arduino/template_project/crt_config *.h -> arduino/crt_config" + ) + + foreach(job_spec IN LISTS ARDUINO_FILE_COPY_JOBS) + string(REPLACE " " ";" job_spec "${job_spec}") + list(LENGTH job_spec job_spec_length) + math(EXPR job_spec_length_mod "${job_spec_length} % 3") + if(NOT "${job_spec_length_mod}" EQUAL 1) + message( + FATAL_ERROR + "Arduino copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}" + ) + endif() + math(EXPR job_spec_stop "${job_spec_length} - 3") + + list(GET job_spec 0 job_src_base) + set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}") + foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3) + list(GET job_spec ${copy_pattern_index} copy_pattern) + math(EXPR copy_dest_index "${copy_pattern_index} + 2") + list(GET job_spec ${copy_dest_index} copy_dest) + + file( + GLOB_RECURSE copy_files + RELATIVE "${job_src_base}" + "${job_src_base}/${copy_pattern}") + list(LENGTH copy_files copy_files_length) + if("${copy_files_length}" EQUAL 0) + message( + FATAL_ERROR + "Arduino copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}" + ) + endif() + foreach(copy_src IN LISTS copy_files) + get_filename_component( + dest_path "${MICROTVM_TEMPLATE_PROJECTS}/${copy_dest}/${copy_src}" + ABSOLUTE) + tvm_micro_add_copy_file(arduino_template_deps + ${job_src_base}/${copy_src} ${dest_path}) + endforeach() + endforeach() + endforeach() + + add_custom_target(arduino DEPENDS ${arduino_template_deps}) + endfunction() + + microtvm_add_arduino() + +endif(USE_MICRO) diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 1491a4558611..1ae250f1bee3 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -53,18 +53,22 @@ if(BUILD_FOR_HEXAGON) include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} ${HEXAGON_QURT_INCLUDES}) endif() -if(USE_HEXAGON_LAUNCHER STREQUAL "ON") - set(USE_HEXAGON_DEVICE "${PICK_SIM}") -else() - if(USE_HEXAGON_DEVICE STREQUAL "OFF") - list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) - return() - elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND - NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") - set(ERROR_MSG - "USE_HEXAGON_DEVICE must be one of [${PICK_NONE}|${PICK_SIM}|${PICK_HW}]") - message(SEND_ERROR "${ERROR_MSG}") - return() +# Don't run these checks when compiling Hexagon device code, +# e.g. when compiling the TVM runtime for Hexagon. +if (NOT BUILD_FOR_HEXAGON) + if(USE_HEXAGON_LAUNCHER STREQUAL "ON") + set(USE_HEXAGON_DEVICE "${PICK_SIM}") + else() + if(USE_HEXAGON_DEVICE STREQUAL "OFF") + list(APPEND COMPILER_SRCS src/target/opt/build_hexagon_off.cc) + return() + elseif(NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}" AND + NOT USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") + set(ERROR_MSG + "USE_HEXAGON_DEVICE must be one of [${PICK_NONE}|${PICK_SIM}|${PICK_HW}]") + message(SEND_ERROR "${ERROR_MSG}") + return() + endif() endif() endif() @@ -76,7 +80,6 @@ if(NOT USE_HEXAGON_SDK) endif() if(USE_HEXAGON_LAUNCHER STREQUAL "ON") - if(DEFINED USE_ANDROID_TOOLCHAIN) if(NOT DEFINED ANDROID_PLATFORM) message(SEND_ERROR "Please set ANDROID_PLATFORM " @@ -91,7 +94,7 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") " launcher for hexagon.") endif() - set(LAUNCHER_BINARY_DIR "${CMAKE_BINARY_DIR}/launcher") + set(LAUNCHER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/apps_hexagon_launcher") ExternalProject_Add(launcher_android SOURCE_DIR "${CMAKE_SOURCE_DIR}/apps/hexagon_launcher/cmake/android" INSTALL_DIR "${LAUNCHER_BINARY_DIR}" @@ -101,14 +104,15 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DANDROID_ABI=${ANDROID_ABI}" "-DFASTRPC_LIBS=STUB" - "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DCMAKE_INSTALL_PREFIX:PATH=" INSTALL_COMMAND "" ) ExternalProject_Get_Property(launcher_android BINARY_DIR) ExternalProject_Add_Step(launcher_android copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/launcher_android ${BINARY_DIR}/libtvm_runtime.so + ${LAUNCHER_BINARY_DIR} DEPENDEES install ) ExternalProject_Add(launcher_hexagon @@ -119,14 +123,15 @@ if(USE_HEXAGON_LAUNCHER STREQUAL "ON") "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang" "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++" "-DFASTRPC_LIBS=SKEL" - "-DUSE_HEXAGON_ARCH=v68" + "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DCMAKE_INSTALL_PREFIX:PATH=" INSTALL_COMMAND "" ) ExternalProject_Get_Property(launcher_hexagon BINARY_DIR) ExternalProject_Add_Step(launcher_hexagon copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_directory ${BINARY_DIR} ${LAUNCHER_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${BINARY_DIR}/liblauncher_rpc_skel.so + ${LAUNCHER_BINARY_DIR} DEPENDEES install ) @@ -136,12 +141,12 @@ endif() if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") find_hexagon_toolchain() message(STATUS "Hexagon toolchain: ${HEXAGON_TOOLCHAIN}") - file(GLOB RUNTIME_HEXAGON_SIM_SRCS src/runtime/hexagon/sim/*.cc) + file(GLOB RUNTIME_HEXAGON_SIM_SRCS src/runtime/hexagon/android/sim/*.cc) include_directories(SYSTEM "${HEXAGON_TOOLCHAIN}/include/iss") link_directories("${HEXAGON_TOOLCHAIN}/lib/iss") list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") ExternalProject_Add(sim_dev - SOURCE_DIR "${CMAKE_SOURCE_DIR}/src/runtime/hexagon/sim/driver" + SOURCE_DIR "${CMAKE_SOURCE_DIR}/src/runtime/hexagon/android/sim/driver" CMAKE_ARGS "-DCMAKE_C_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang" "-DCMAKE_CXX_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" @@ -151,7 +156,7 @@ if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") find_hexagon_sdk_root("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}") find_hexagon_toolchain() - file(GLOB RUNTIME_HEXAGON_DEVICE_SRCS src/runtime/hexagon/target/*.cc) + file(GLOB RUNTIME_HEXAGON_DEVICE_SRCS src/runtime/hexagon/android/target/*.cc) include_directories(SYSTEM ${HEXAGON_SDK_INCLUDES} @@ -165,7 +170,10 @@ elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") endif() endif() -file(GLOB RUNTIME_HEXAGON_SRCS src/runtime/hexagon/*.cc) +if(BUILD_FOR_HEXAGON AND USE_HEXAGON_DEVICE STREQUAL "${PICK_NONE}") + file(GLOB RUNTIME_HEXAGON_SRCS src/runtime/hexagon/hexagon/*.cc) +else() + file(GLOB RUNTIME_HEXAGON_SRCS src/runtime/hexagon/android/*.cc) +endif() list(APPEND RUNTIME_SRCS ${RUNTIME_HEXAGON_SRCS} ${RUNTIME_HEXAGON_SIM_SRCS} ${RUNTIME_HEXAGON_DEVICE_SRCS}) - diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 163a56dbd1d4..bf548b232512 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -60,6 +60,7 @@ function(add_lib_info src_file) TVM_INFO_INSTALL_DEV="${INSTALL_DEV}" TVM_INFO_HIDE_PRIVATE_SYMBOLS="${HIDE_PRIVATE_SYMBOLS}" TVM_INFO_USE_TF_TVMDSOOP="${USE_TF_TVMDSOOP}" + TVM_INFO_USE_PT_TVMDSOOP="${USE_PT_TVMDSOOP}" TVM_INFO_USE_FALLBACK_STL_MAP="${USE_FALLBACK_STL_MAP}" TVM_INFO_USE_BYODT_POSIT="${USE_BYODT_POSIT}" TVM_INFO_USE_BLAS="${USE_BLAS}" diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index 9f79c7da3cdf..5d822844ae34 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -16,20 +16,9 @@ # under the License. if(USE_MICRO) - message(STATUS "Build standalone CRT for micro TVM") + message(STATUS "Build standalone CRT for microTVM") file(GLOB crt_srcs src/runtime/crt/**) - function(tvm_crt_add_copy_file var src dest) - get_filename_component(basename "${src}" NAME) - get_filename_component(dest_parent_dir "${dest}" DIRECTORY) - add_custom_command( - OUTPUT "${dest}" - COMMAND "${CMAKE_COMMAND}" -E copy "${src}" "${dest}" - DEPENDS "${src}") - list(APPEND "${var}" "${dest}") - set("${var}" "${${var}}" PARENT_SCOPE) - endfunction(tvm_crt_add_copy_file) - function(tvm_crt_define_targets) # Build an isolated build directory, separate from the TVM tree. list(APPEND CRT_FILE_COPY_JOBS @@ -83,7 +72,7 @@ if(USE_MICRO) endif() foreach(copy_src IN LISTS copy_files) get_filename_component(dest_path "${standalone_crt_base}/${copy_dest}/${copy_src}" ABSOLUTE) - tvm_crt_add_copy_file(host_isolated_build_deps ${job_src_base}/${copy_src} ${dest_path}) + tvm_micro_add_copy_file(host_isolated_build_deps ${job_src_base}/${copy_src} ${dest_path}) endforeach() endforeach() endforeach() diff --git a/cmake/modules/Zephyr.cmake b/cmake/modules/Zephyr.cmake new file mode 100644 index 000000000000..048240375cd6 --- /dev/null +++ b/cmake/modules/Zephyr.cmake @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. The ASF licenses this +# file to you under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +if(USE_MICRO) + message(STATUS "Add Zephyr for microTVM") + + function(microtvm_add_zephyr) + list( + APPEND + ZEPHYR_FILE_COPY_JOBS + "apps/microtvm/zephyr/template_project microtvm_api_server.py -> zephyr" + "apps/microtvm/zephyr/template_project boards.json -> zephyr" + "apps/microtvm/zephyr/template_project CMakeLists.txt.template -> zephyr" + "apps/microtvm/zephyr/template_project/src/aot_demo *.c -> zephyr/src/aot_demo" + "apps/microtvm/zephyr/template_project/src/aot_demo *.h -> zephyr/src/aot_demo" + "apps/microtvm/zephyr/template_project/src/host_driven *.c -> zephyr/src/host_driven" + "apps/microtvm/zephyr/template_project/qemu-hack * -> zephyr/qemu-hack" + "apps/microtvm/zephyr/template_project/crt_config *.h -> zephyr/crt_config" + ) + + foreach(job_spec IN LISTS ZEPHYR_FILE_COPY_JOBS) + string(REPLACE " " ";" job_spec "${job_spec}") + list(LENGTH job_spec job_spec_length) + math(EXPR job_spec_length_mod "${job_spec_length} % 3") + if(NOT "${job_spec_length_mod}" EQUAL 1) + message( + FATAL_ERROR + "Zephyr copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}" + ) + endif() + math(EXPR job_spec_stop "${job_spec_length} - 3") + + list(GET job_spec 0 job_src_base) + set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}") + foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3) + list(GET job_spec ${copy_pattern_index} copy_pattern) + math(EXPR copy_dest_index "${copy_pattern_index} + 2") + list(GET job_spec ${copy_dest_index} copy_dest) + + file( + GLOB_RECURSE copy_files + RELATIVE "${job_src_base}" + "${job_src_base}/${copy_pattern}") + list(LENGTH copy_files copy_files_length) + if("${copy_files_length}" EQUAL 0) + message( + FATAL_ERROR + "Zephyr copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}" + ) + endif() + foreach(copy_src IN LISTS copy_files) + get_filename_component( + dest_path "${MICROTVM_TEMPLATE_PROJECTS}/${copy_dest}/${copy_src}" + ABSOLUTE) + tvm_micro_add_copy_file(zephyr_template_deps + ${job_src_base}/${copy_src} ${dest_path}) + endforeach() + endforeach() + endforeach() + + add_custom_target(zephyr DEPENDS ${zephyr_template_deps}) + endfunction() + + microtvm_add_zephyr() + +endif(USE_MICRO) diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake new file mode 100644 index 000000000000..10309f0d90b3 --- /dev/null +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_CUDA AND USE_CUTLASS) + file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc) + list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC}) + + message(STATUS "Build with CUTLASS") +endif() diff --git a/cmake/modules/contrib/EthosN.cmake b/cmake/modules/contrib/EthosN.cmake index 6eb5271f91b9..44d2a2a17ace 100644 --- a/cmake/modules/contrib/EthosN.cmake +++ b/cmake/modules/contrib/EthosN.cmake @@ -20,10 +20,6 @@ if(NOT USE_ETHOSN STREQUAL "OFF") find_ethosn(${USE_ETHOSN}) - if(NOT DEFINED TVM_LLVM_VERSION) - message(FATAL_ERROR "Support for offloading to Ethos-N requires LLVM Support") - endif() - if(NOT ETHOSN_FOUND) message(FATAL_ERROR "Cannot find Ethos-N, USE_ETHOSN=" ${USE_ETHOSN}) diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake new file mode 100644 index 000000000000..7ff88693fe4e --- /dev/null +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") + find_package(Python3 COMPONENTS Interpreter Development) + include_directories(${Python3_INCLUDE_DIRS}) + + message(STATUS "Python3_INCLUDE_DIRS: ${Python3_INCLUDE_DIRS}") + + execute_process(COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())" + OUTPUT_VARIABLE PT_PATH + RESULT_VARIABLE PT_STATUS) + if (NOT ${PT_STATUS} EQUAL 0) + message(FATAL_ERROR "Fail to get pytorch path") + endif() + + string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}") + + set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0") + set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") + + if(NOT USE_CUDA STREQUAL "OFF") + add_definitions(-DPT_TVMDSOOP_ENABLE_GPU) + endif() + + + string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}") + separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR}) + separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) + + + set(LIBRARY_NAME pt_tvmdsoop) + file(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc) + add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) + set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) + + if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") + add_dependencies(${LIBRARY_NAME} tvm) + endif() + + target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + +endif() + diff --git a/cmake/utils/Utils.cmake b/cmake/utils/Utils.cmake index 4e6762b14894..44f622126abb 100644 --- a/cmake/utils/Utils.cmake +++ b/cmake/utils/Utils.cmake @@ -75,6 +75,19 @@ function(assign_source_group group) endforeach() endfunction(assign_source_group) +function(tvm_micro_add_copy_file var src dest) + get_filename_component(basename "${src}" NAME) + get_filename_component(dest_parent_dir "${dest}" DIRECTORY) + add_custom_command( + OUTPUT "${dest}" + COMMAND "${CMAKE_COMMAND}" -E copy "${src}" "${dest}" + DEPENDS "${src}") + list(APPEND "${var}" "${dest}") + set("${var}" "${${var}}" PARENT_SCOPE) +endfunction(tvm_micro_add_copy_file) + +set(MICROTVM_TEMPLATE_PROJECTS "${CMAKE_CURRENT_BINARY_DIR}/microtvm_template_projects") + # From cmake documentation: # True if the constant is 1, ON, YES, TRUE, Y, or a non-zero number. # False if the constant is 0, OFF, NO, FALSE, N, IGNORE, NOTFOUND, the empty string, or ends in the suffix -NOTFOUND. diff --git a/docker/install/ubuntu1804_install_llvm.sh b/docker/install/ubuntu1804_install_llvm.sh index 58399d535d92..b4640aa9ae6e 100755 --- a/docker/install/ubuntu1804_install_llvm.sh +++ b/docker/install/ubuntu1804_install_llvm.sh @@ -41,10 +41,22 @@ echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-12 main\ echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-12 main\ >> /etc/apt/sources.list.d/llvm.list +echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-13 main\ + >> /etc/apt/sources.list.d/llvm.list +echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-13 main\ + >> /etc/apt/sources.list.d/llvm.list + echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic main\ >> /etc/apt/sources.list.d/llvm.list echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic main\ >> /etc/apt/sources.list.d/llvm.list wget -q -O - http://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - -apt-get update && apt-get install -y llvm-9 llvm-10 llvm-11 llvm-12 clang-9 libclang-9-dev clang-10 libclang-10-dev clang-11 libclang-11-dev clang-12 libclang-12-dev + +apt-get update && apt-get install -y \ + llvm-9 llvm-10 llvm-11 llvm-12 llvm-13 \ + clang-9 libclang-9-dev \ + clang-10 libclang-10-dev \ + clang-11 libclang-11-dev \ + clang-12 libclang-12-dev \ + clang-13 libclang-13-dev diff --git a/docker/install/ubuntu1804_install_python.sh b/docker/install/ubuntu1804_install_python.sh index 6b4d6fb4f727..693d8f8b99db 100755 --- a/docker/install/ubuntu1804_install_python.sh +++ b/docker/install/ubuntu1804_install_python.sh @@ -28,5 +28,5 @@ apt-get install -y python3-dev python3-setuptools # Install pip cd /tmp && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py -# Pin pip version -pip3 install pip==19.3.1 +# Pin pip and setuptools versions +pip3 install pip==19.3.1 setuptools==58.4.0 diff --git a/docker/install/ubuntu1804_install_python_venv.sh b/docker/install/ubuntu1804_install_python_venv.sh index 86f0dd5e1223..fe234e035573 100755 --- a/docker/install/ubuntu1804_install_python_venv.sh +++ b/docker/install/ubuntu1804_install_python_venv.sh @@ -27,5 +27,5 @@ apt-get install -y python3-dev python3-setuptools python3-venv python3 -mvenv /opt/tvm-venv -# Pin pip version -/opt/tvm-venv/bin/pip3 install pip==19.3.1 +# Pin pip and setuptools versions +/opt/tvm-venv/bin/pip3 install pip==19.3.1 setuptools==58.4.0 diff --git a/docker/install/ubuntu_download_arm_compute_lib_binaries.sh b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh index c68654c75392..5097fad3d0b6 100755 --- a/docker/install/ubuntu_download_arm_compute_lib_binaries.sh +++ b/docker/install/ubuntu_download_arm_compute_lib_binaries.sh @@ -27,17 +27,19 @@ if [ "$architecture_type" != "aarch64" ]; then gcc-aarch64-linux-gnu fi -compute_lib_version="v21.05" +compute_lib_version="v21.08" +compute_lib_variant="arm64-v8a-neon" +compute_lib_full_name="arm_compute-${compute_lib_version}-bin-linux-${compute_lib_variant}" compute_lib_base_url="https://github.com/ARM-software/ComputeLibrary/releases/download/${compute_lib_version}" -compute_lib_file_name="arm_compute-${compute_lib_version}-bin-linux.tar.gz" +compute_lib_file_name="${compute_lib_full_name}.tar.gz" compute_lib_download_url="${compute_lib_base_url}/${compute_lib_file_name}" -target_lib="linux-arm64-v8a-neon" +target_lib="${compute_lib_variant}" # uncomment line below if you need asserts/debug version of the library # target_lib="${target_lib}-asserts" -extract_dir="arm_compute-${compute_lib_version}-bin-linux" +extract_dir="${compute_lib_full_name}" install_path="/opt/acl" tmpdir=$(mktemp -d) diff --git a/docker/install/ubuntu_install_arduino.sh b/docker/install/ubuntu_install_arduino.sh index c374850aa1df..a612261b2a2b 100644 --- a/docker/install/ubuntu_install_arduino.sh +++ b/docker/install/ubuntu_install_arduino.sh @@ -23,8 +23,9 @@ set -o pipefail export DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates -# Install arduino-cli latest version -wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s +ARDUINO_CLI_VERSION="0.18.3" +# Install arduino-cli +wget -O - https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh -s ${ARDUINO_CLI_VERSION} # Install the cores we want to test on arduino-cli core install arduino:mbed_nano diff --git a/docker/install/ubuntu_install_cmake_source.sh b/docker/install/ubuntu_install_cmake_source.sh index f818fba9721b..18335c98c403 100644 --- a/docker/install/ubuntu_install_cmake_source.sh +++ b/docker/install/ubuntu_install_cmake_source.sh @@ -20,8 +20,8 @@ set -e set -u set -o pipefail -v=3.13 -version=3.13.5 +v=3.14 +version=3.14.7 wget https://cmake.org/files/v${v}/cmake-${version}.tar.gz tar xvf cmake-${version}.tar.gz cd cmake-${version} diff --git a/docker/install/ubuntu_install_python.sh b/docker/install/ubuntu_install_python.sh index d3af336491cc..b71398ad5fc8 100755 --- a/docker/install/ubuntu_install_python.sh +++ b/docker/install/ubuntu_install_python.sh @@ -36,5 +36,5 @@ rm -f /usr/bin/python3 && ln -s /usr/bin/python3.6 /usr/bin/python3 # Install pip cd /tmp && wget -q https://bootstrap.pypa.io/get-pip.py && python3.6 get-pip.py -# Pin pip version -pip3 install pip==19.3.1 +# Pin pip and setuptools versions +pip3 install pip==19.3.1 setuptools==58.4.0 diff --git a/docker/install/ubuntu_install_rocm.sh b/docker/install/ubuntu_install_rocm.sh index 0945c582489f..2f28356da3c8 100755 --- a/docker/install/ubuntu_install_rocm.sh +++ b/docker/install/ubuntu_install_rocm.sh @@ -21,10 +21,10 @@ set -u set -o pipefail # Install ROCm cross compilation toolchain. -wget -qO - http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | sudo apt-key add - -echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/debian/ xenial main > /etc/apt/sources.list.d/rocm.list +wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | sudo apt-key add - +echo 'deb [arch=amd64] https://repo.radeon.com/rocm/apt/4.3/ ubuntu main' | sudo tee /etc/apt/sources.list.d/rocm.list apt-get update && apt-get install -y \ rocm-dev \ - lld && \ + lld-12 && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* diff --git a/docker/install/ubuntu_install_vela.sh b/docker/install/ubuntu_install_vela.sh index e75a99d9d563..a880a6f440e8 100644 --- a/docker/install/ubuntu_install_vela.sh +++ b/docker/install/ubuntu_install_vela.sh @@ -20,7 +20,6 @@ set -e set -u set -o pipefail -pip3 install -U setuptools # In a refactor between v2.1.1 and v3.0.0, find_block_configs was removed from Vela. # Since this is still required for the TVM port, it will be reinstated in Vela in a future release. # Until then, it needs to be pinned to v2.1.1. diff --git a/docker/install/ubuntu_install_zephyr.sh b/docker/install/ubuntu_install_zephyr.sh index 7e5aae96a38f..ddd1ea1c1734 100644 --- a/docker/install/ubuntu_install_zephyr.sh +++ b/docker/install/ubuntu_install_zephyr.sh @@ -44,9 +44,13 @@ sudo apt-get install -y cmake pip3 install west # Init ZephyrProject +# To keep in sync with the version +# defined in apps/microtvm/zephyr/template_project/microtvm_api_server.py +# We use `-branch` tag since it tracks the same version with extra patches for bugs. +ZEPHYR_VERSION="v2.5-branch" ZEPHYR_PROJECT_PATH=/opt/zephyrproject ZEPHYR_INIT_SCRIPT=$(find -name "ubuntu_init_zephyr_project.sh") -bash ${ZEPHYR_INIT_SCRIPT} ${ZEPHYR_PROJECT_PATH} v2.5-branch +bash ${ZEPHYR_INIT_SCRIPT} ${ZEPHYR_PROJECT_PATH} ${ZEPHYR_VERSION} cd ${ZEPHYR_PROJECT_PATH} # As part of the build process, Zephyr needs to touch some symlinks in zephyr/misc/generated/syscalls_links (this path is relative to the diff --git a/docs/arch/relay_op_strategy.rst b/docs/arch/relay_op_strategy.rst index c40251d22433..dbac7c821827 100644 --- a/docs/arch/relay_op_strategy.rst +++ b/docs/arch/relay_op_strategy.rst @@ -269,14 +269,14 @@ will then be chosen. Implementations with same priority level in this case leads to an undefined behavior, and any of them might be selected. The selection policy for ops with symbolic input shapes is still work in -progess. Currently, if any input tensor has a symbolic shape, only the +progress. Currently, if any input tensor has a symbolic shape, only the implementation with highest priority level will be used for this operator. This -will be updated after the implemention finishes. +will be updated after the implementation finishes. For debug purpose, you can add the following lines before you compile the Relay model to learn which implementation is used for each operator. .. code:: python - logging.getLogger("compile_engine").setLevel(logging.INFO) - logging.getLogger("compile_engine").addHandler(logging.StreamHandler(sys.stdout)) + logging.getLogger("te_compiler").setLevel(logging.INFO) + logging.getLogger("te_compiler").addHandler(logging.StreamHandler(sys.stdout)) diff --git a/docs/conf.py b/docs/conf.py index 766fda49997f..893d89c26156 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,7 @@ sys.path.insert(0, str(tvm_path.resolve() / "python")) sys.path.insert(0, str(tvm_path.resolve() / "vta" / "python")) +sys.path.insert(0, str(tvm_path.resolve() / "docs")) # -- General configuration ------------------------------------------------ @@ -258,6 +259,7 @@ def git_describe_version(original_version): "tensor_expr_get_started.py", "autotvm_matmul_x86.py", "auto_scheduler_matmul_x86.py", + "tensor_ir_blitz_course.py", "topi.pi", "cross_compilation_and_rpc.py", "relay_quick_start.py", @@ -466,5 +468,9 @@ def process_docstring(app, what, name, obj, options, lines): update_alias_docstring(name, obj, lines) +from legacy_redirect import build_legacy_redirect + + def setup(app): app.connect("autodoc-process-docstring", process_docstring) + app.connect("build-finished", build_legacy_redirect(tvm_path)) diff --git a/docs/contribute/git_howto.rst b/docs/contribute/git_howto.rst index 458573630aa5..765153be220b 100644 --- a/docs/contribute/git_howto.rst +++ b/docs/contribute/git_howto.rst @@ -23,7 +23,8 @@ Git Usage Tips Here are some tips for git workflow. -## How to resolve conflict with main +How to resolve a conflict with `main` +------------------------------------- - First rebase to most recent main diff --git a/docs/dev/how_to/pytest_target_parametrization.rst b/docs/dev/how_to/pytest_target_parametrization.rst index 6dfcaf3633be..3fbb69401d16 100644 --- a/docs/dev/how_to/pytest_target_parametrization.rst +++ b/docs/dev/how_to/pytest_target_parametrization.rst @@ -21,7 +21,7 @@ Python Target Parametrization Summary ------- -For any supported runtime, TVM should should produce numerically +For any supported runtime, TVM should produce numerically correct results. Therefore, when writing unit tests that validate the numeric output, these unit tests should be run on all supported runtimes. Since this is a very common use case, TVM has helper @@ -29,7 +29,7 @@ functions to parametrize unit tests such that they will run on all targets that are enabled and have a compatible device. A single python function in the test suite can expand to several -parametrized unit tests, each of which tests a single target device. +parameterized unit tests, each of which tests a single target device. In order for a test to be run, all of the following must be true. - The test exists in a file or directory that has been passed to @@ -129,11 +129,11 @@ marks are as follows. - ``@pytest.mark.gpu`` - Tags a function as using GPU capabilities. This has no effect on its own, but can be paired with command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict - which tests pytest will executed. This should not be called on its + which tests pytest will execute. This should not be called on its own, but is part of other marks used in unit-tests. - ``@tvm.testing.uses_gpu`` - Applies ``@pytest.mark.gpu``. This - should be used to mark a unit tests that may use the GPU, if one is + should be used to mark unit tests that may use the GPU, if one is present. This decorator is only needed for tests that explicitly loop over ``tvm.testing.enabled_targets()``, but that is no longer the preferred style of writing unit tests (see below). When using @@ -161,7 +161,7 @@ There also exists a ``tvm.testing.enabled_targets()`` that returns all targets that are enabled and runnable on the current machine, based on the environment variable ``TVM_TEST_TARGETS``, the build configuration, and the physical hardware present. Most current tests -explictly loop over the targets returned from ``enabled_targets()``, +explicitly loop over the targets returned from ``enabled_targets()``, but it should not be used for new tests. The pytest output for this style silently skips runtimes that are disabled in ``config.cmake``, or do not have a device on which they can run. In addition, the test diff --git a/docs/how_to/deploy/arm_compute_lib.rst b/docs/how_to/deploy/arm_compute_lib.rst index 831438273cca..a7ec8b9501c7 100644 --- a/docs/how_to/deploy/arm_compute_lib.rst +++ b/docs/how_to/deploy/arm_compute_lib.rst @@ -34,32 +34,31 @@ Before installing Arm Compute Library, it is important to know what architecture to determine this is to use `lscpu` and look for the "Model name" of the CPU. You can then use this to determine the architecture by looking online. -We recommend two different ways to build and install ACL: - -* Use the script located at `docker/install/ubuntu_install_arm_compute_lib.sh`. You can use this - script for building ACL from source natively or for cross-compiling the library on an x86 machine. - You may need to change the architecture of the device you wish to compile for by altering the - `target_arch` variable. Binaries will be built from source and installed to the location denoted by - `install_path`. -* Alternatively, you can download and use pre-built binaries from: +TVM only supports a single version of ACL, currently this is v21.08, there are two recommended ways to build and install +the required libraries: + +* Use the script located at `docker/install/ubuntu_download_arm_compute_lib_binaries.sh`. You can use this + script for downloading ACL binaries for the architecture and extensions specified in `target_lib`, these + will be installed to the location denoted by `install_path`. +* Alternatively, you can download the pre-built binaries from: https://github.com/ARM-software/ComputeLibrary/releases. When using this package, you will need to - select the binaries for the architecture you require and make sure they are visible to cmake. This - can be done like so: + select the binaries for the architecture and extensions you require, then make sure they are visible + to CMake: .. code:: bash cd /lib - mv ./linux--neon/* . + mv .//* . In both cases you will need to set USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR to the path where the ACL package -is located. Cmake will look in /path-to-acl/ along with /path-to-acl/lib and /path-to-acl/build for the +is located. CMake will look in /path-to-acl/ along with /path-to-acl/lib and /path-to-acl/build for the required binaries. See the section below for more information on how to use these configuration options. Building with ACL support ------------------------- -The current implementation has two separate build options in cmake. The reason for this split is +The current implementation has two separate build options in CMake. The reason for this split is because ACL cannot be used on an x86 machine. However, we still want to be able compile an ACL runtime module on an x86 machine. @@ -73,7 +72,7 @@ need to use USE_ARM_COMPUTE_LIB=ON on the x86 machine and USE_ARM_COMPUTE_LIB_GR AArch64 device. By default both options are set to OFF. Using USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR=ON will mean that ACL -binaries are searched for by cmake in the default locations +binaries are searched for by CMake in the default locations (see https://cmake.org/cmake/help/v3.4/command/find_library.html). In addition to this, /path-to-tvm-project/acl/ will also be searched. It is likely that you will need to set your own path to locate ACL. This can be done by specifying a path in the place of ON. diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index 23be3198bf7c..4fad42b0af76 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -123,6 +123,9 @@ The configuration of TVM can be modified by editing `config.cmake` and/or by pas - Note that apt-package append ``llvm-config`` with version number. For example, set ``set(USE_LLVM llvm-config-10)`` if you installed LLVM 10 package + - If you are a PyTorch user, it is recommended to set ``(USE_LLVM "/path/to/llvm-config --link-static")`` and ``set(HIDE_PRIVATE_SYMBOLS ON)`` + to avoid potential symbol conflicts between different versions LLVM used by TVM and PyTorch. + - We can then build tvm and related libraries. .. code:: bash diff --git a/docs/legacy_redirect.py b/docs/legacy_redirect.py new file mode 100644 index 000000000000..0f1340e5491f --- /dev/null +++ b/docs/legacy_redirect.py @@ -0,0 +1,272 @@ +# -*- coding: utf-8 -*- + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from string import Template +import json +import os + +legacy_redirects = [ + ["dev/benchmark.html", "../arch/benchmark.html"], + ["dev/convert_layout.html", "../arch/convert_layout.html"], + ["dev/debugger.html", "../arch/debugger.html"], + ["dev/device_target_interactions.html", "../arch/device_target_interactions.html"], + ["dev/frontend/tensorflow.html", "../../arch/frontend/tensorflow.html"], + ["dev/hybrid_script.html", "../arch/hybrid_script.html"], + ["dev/index.html", "../arch/index.html"], + ["dev/inferbound.html", "../arch/inferbound.html"], + [ + "dev/introduction_to_module_serialization.html", + "../arch/introduction_to_module_serialization.html", + ], + ["dev/microtvm_design.html", "../arch/microtvm_design.html"], + ["dev/model_library_format.html", "../arch/model_library_format.html"], + ["dev/pass_infra.html", "../arch/pass_infra.html"], + ["dev/relay_intro.html", "../arch/relay_intro.html"], + ["dev/relay_op_strategy.html", "../arch/relay_op_strategy.html"], + ["dev/runtime.html", "../arch/runtime.html"], + ["dev/runtimes/vulkan.html", "../../arch/runtimes/vulkan.html"], + ["dev/security.html", "../arch/security.html"], + ["dev/virtual_machine.html", "../arch/virtual_machine.html"], + ["dev/how_to.html", "index.html"], + ["dev/pytest_target_parametrization.html", "how_to/pytest_target_parametrization.html"], + ["dev/relay_add_op.html", "how_to/relay_add_op.html"], + ["dev/relay_add_pass.html", "how_to/relay_add_pass.html"], + ["dev/relay_bring_your_own_codegen.html", "how_to/relay_bring_your_own_codegen.html"], + ["dev/codebase_walkthrough.html", "tutorial/codebase_walkthrough.html"], + ["deploy/android.html", "../how_to/deploy/android.html"], + ["deploy/arm_compute_lib.html", "../how_to/deploy/arm_compute_lib.html"], + ["deploy/bnns.html", "../how_to/deploy/bnns.html"], + ["deploy/cpp_deploy.html", "../how_to/deploy/cpp_deploy.html"], + ["deploy/hls.html", "../how_to/deploy/hls.html"], + ["deploy/index.html", "../how_to/deploy/index.html"], + ["deploy/integrate.html", "../how_to/deploy/integrate.html"], + ["deploy/tensorrt.html", "../how_to/deploy/tensorrt.html"], + ["deploy/vitis_ai.html", "../how_to/deploy/vitis_ai.html"], + ["profiling/index.html", "../how_to/profile/index.html"], + ["profiling/papi.html", "../how_to/profile/papi.html"], + ["api/links.html", "../reference/api/links.html"], + ["api/python/auto_scheduler.html", "../../reference/api/python/auto_scheduler.html"], + ["api/python/autotvm.html", "../../reference/api/python/autotvm.html"], + ["api/python/contrib.html", "../../reference/api/python/contrib.html"], + ["api/python/driver.html", "../../reference/api/python/driver.html"], + ["api/python/error.html", "../../reference/api/python/error.html"], + ["api/python/graph_executor.html", "../../reference/api/python/graph_executor.html"], + ["api/python/index.html", "../../reference/api/python/index.html"], + ["api/python/ir.html", "../../reference/api/python/ir.html"], + ["api/python/micro.html", "../../reference/api/python/micro.html"], + ["api/python/ndarray.html", "../../reference/api/python/ndarray.html"], + ["api/python/relay/analysis.html", "../../../reference/api/python/relay/analysis.html"], + ["api/python/relay/backend.html", "../../../reference/api/python/relay/backend.html"], + [ + "api/python/relay/dataflow_pattern.html", + "../../../reference/api/python/relay/dataflow_pattern.html", + ], + ["api/python/relay/frontend.html", "../../../reference/api/python/relay/frontend.html"], + ["api/python/relay/image.html", "../../../reference/api/python/relay/image.html"], + ["api/python/relay/index.html", "../../../reference/api/python/relay/index.html"], + ["api/python/relay/nn.html", "../../../reference/api/python/relay/nn.html"], + ["api/python/relay/testing.html", "../../../reference/api/python/relay/testing.html"], + ["api/python/relay/transform.html", "../../../reference/api/python/relay/transform.html"], + ["api/python/relay/vision.html", "../../../reference/api/python/relay/vision.html"], + ["api/python/rpc.html", "../../reference/api/python/rpc.html"], + ["api/python/runtime.html", "../../reference/api/python/runtime.html"], + ["api/python/target.html", "../../reference/api/python/target.html"], + ["api/python/te.html", "../../reference/api/python/te.html"], + ["api/python/tir.html", "../../reference/api/python/tir.html"], + ["api/python/topi.html", "../../reference/api/python/topi.html"], + ["api/python/vta/index.html", "../../../reference/api/python/vta/index.html"], + ["langref/hybrid_script.html", "../reference/langref/hybrid_script.html"], + ["langref/index.html", "../reference/langref/index.html"], + ["langref/relay_adt.html", "../reference/langref/relay_adt.html"], + ["langref/relay_expr.html", "../reference/langref/relay_expr.html"], + ["langref/relay_op.html", "../reference/langref/relay_op.html"], + ["langref/relay_pattern.html", "../reference/langref/relay_pattern.html"], + ["langref/relay_type.html", "../reference/langref/relay_type.html"], + ["microtvm/index.html", "../topic/microtvm/index.html"], + ["vta/dev/config.html", "../../topic/vta/dev/config.html"], + ["vta/dev/hardware.html", "../../topic/vta/dev/hardware.html"], + ["vta/dev/index.html", "../../topic/vta/dev/index.html"], + ["vta/index.html", "../topic/vta/index.html"], + ["vta/install.html", "../topic/vta/install.html"], + ["tutorials/frontend/from_caffe2.html", "../../how_to/compile_models/from_caffe2.html"], + ["tutorials/frontend/from_coreml.html", "../../how_to/compile_models/from_coreml.html"], + ["tutorials/frontend/from_darknet.html", "../../how_to/compile_models/from_darknet.html"], + ["tutorials/frontend/from_keras.html", "../../how_to/compile_models/from_keras.html"], + ["tutorials/frontend/from_mxnet.html", "../../how_to/compile_models/from_mxnet.html"], + ["tutorials/frontend/from_onnx.html", "../../how_to/compile_models/from_onnx.html"], + ["tutorials/frontend/from_paddle.html", "../../how_to/compile_models/from_paddle.html"], + ["tutorials/frontend/from_pytorch.html", "../../how_to/compile_models/from_pytorch.html"], + ["tutorials/frontend/from_tensorflow.html", "../../how_to/compile_models/from_tensorflow.html"], + ["tutorials/frontend/from_tflite.html", "../../how_to/compile_models/from_tflite.html"], + [ + "tutorials/frontend/deploy_model_on_android.html", + "../../how_to/deploy_models/deploy_model_on_android.html", + ], + [ + "tutorials/frontend/deploy_model_on_rasp.html", + "../../how_to/deploy_models/deploy_model_on_rasp.html", + ], + [ + "tutorials/frontend/deploy_object_detection_pytorch.html", + "../../how_to/deploy_models/deploy_object_detection_pytorch.html", + ], + [ + "tutorials/frontend/deploy_prequantized.html", + "../../how_to/deploy_models/deploy_prequantized.html", + ], + [ + "tutorials/frontend/deploy_prequantized_tflite.html", + "../../how_to/deploy_models/deploy_prequantized_tflite.html", + ], + [ + "tutorials/frontend/deploy_quantized.html", + "../../how_to/deploy_models/deploy_quantized.html", + ], + ["tutorials/frontend/deploy_sparse.html", "../../how_to/deploy_models/deploy_sparse.html"], + [ + "tutorials/frontend/deploy_ssd_gluoncv.html", + "../../how_to/deploy_models/deploy_ssd_gluoncv.html", + ], + [ + "tutorials/dev/bring_your_own_datatypes.html", + "../../how_to/extend_tvm/bring_your_own_datatypes.html", + ], + [ + "tutorials/dev/low_level_custom_pass.html", + "../../how_to/extend_tvm/low_level_custom_pass.html", + ], + ["tutorials/dev/use_pass_infra.html", "../../how_to/extend_tvm/use_pass_infra.html"], + ["tutorials/dev/use_pass_instrument.html", "../../how_to/extend_tvm/use_pass_instrument.html"], + ["tutorials/optimize/opt_conv_cuda.html", "../../how_to/optimize_operators/opt_conv_cuda.html"], + [ + "tutorials/optimize/opt_conv_tensorcore.html", + "../../how_to/optimize_operators/opt_conv_tensorcore.html", + ], + ["tutorials/optimize/opt_gemm.html", "../../how_to/optimize_operators/opt_gemm.html"], + [ + "tutorials/auto_scheduler/tune_conv2d_layer_cuda.html", + "../../how_to/tune_with_autoscheduler/tune_conv2d_layer_cuda.html", + ], + [ + "tutorials/auto_scheduler/tune_network_arm.html", + "../../how_to/tune_with_autoscheduler/tune_network_arm.html", + ], + [ + "tutorials/auto_scheduler/tune_network_cuda.html", + "../../how_to/tune_with_autoscheduler/tune_network_cuda.html", + ], + [ + "tutorials/auto_scheduler/tune_network_mali.html", + "../../how_to/tune_with_autoscheduler/tune_network_mali.html", + ], + [ + "tutorials/auto_scheduler/tune_network_x86.html", + "../../how_to/tune_with_autoscheduler/tune_network_x86.html", + ], + [ + "tutorials/auto_scheduler/tune_sparse_x86.html", + "../../how_to/tune_with_autoscheduler/tune_sparse_x86.html", + ], + [ + "tutorials/autotvm/tune_conv2d_cuda.html", + "../../how_to/tune_with_autotvm/tune_conv2d_cuda.html", + ], + ["tutorials/autotvm/tune_relay_arm.html", "../../how_to/tune_with_autotvm/tune_relay_arm.html"], + [ + "tutorials/autotvm/tune_relay_cuda.html", + "../../how_to/tune_with_autotvm/tune_relay_cuda.html", + ], + [ + "tutorials/autotvm/tune_relay_mobile_gpu.html", + "../../how_to/tune_with_autotvm/tune_relay_mobile_gpu.html", + ], + ["tutorials/autotvm/tune_relay_x86.html", "../../how_to/tune_with_autotvm/tune_relay_x86.html"], + ["tutorials/micro/micro_autotune.html", "../../how_to/work_with_microtvm/micro_autotune.html"], + [ + "tutorials/micro/micro_reference_vm.html", + "../../how_to/work_with_microtvm/micro_reference_vm.html", + ], + ["tutorials/micro/micro_tflite.html", "../../how_to/work_with_microtvm/micro_tflite.html"], + ["tutorials/frontend/build_gcn.html", "../../how_to/work_with_relay/build_gcn.html"], + [ + "tutorials/frontend/using_external_lib.html", + "../../how_to/work_with_relay/using_external_lib.html", + ], + ["tutorials/language/extern_op.html", "../../how_to/work_with_schedules/extern_op.html"], + ["tutorials/language/intrin_math.html", "../../how_to/work_with_schedules/intrin_math.html"], + ["tutorials/language/reduction.html", "../../how_to/work_with_schedules/reduction.html"], + ["tutorials/language/scan.html", "../../how_to/work_with_schedules/scan.html"], + [ + "tutorials/language/schedule_primitives.html", + "../../how_to/work_with_schedules/schedule_primitives.html", + ], + ["tutorials/language/tedd.html", "../../how_to/work_with_schedules/tedd.html"], + ["tutorials/language/tensorize.html", "../../how_to/work_with_schedules/tensorize.html"], + ["tutorials/language/tuple_inputs.html", "../../how_to/work_with_schedules/tuple_inputs.html"], + [ + "tutorials/get_started/auto_scheduler_matmul_x86.html", + "../../tutorial/auto_scheduler_matmul_x86.html", + ], + ["tutorials/get_started/autotvm_matmul_x86.html", "../../tutorial/autotvm_matmul_x86.html"], + ["tutorials/get_started/autotvm_relay_x86.html", "../../tutorial/autotvm_relay_x86.html"], + [ + "tutorials/get_started/cross_compilation_and_rpc.html", + "../../tutorial/cross_compilation_and_rpc.html", + ], + ["tutorials/get_started/install.html", "../../tutorial/install.html"], + ["tutorials/topi/intro_topi.html", "../../tutorial/intro_topi.html"], + ["tutorials/get_started/introduction.html", "../../tutorial/introduction.html"], + ["tutorials/get_started/relay_quick_start.html", "../../tutorial/relay_quick_start.html"], + [ + "tutorials/get_started/tensor_expr_get_started.html", + "../../tutorial/tensor_expr_get_started.html", + ], + [ + "tutorials/get_started/tvmc_command_line_driver.html", + "../../tutorial/tvmc_command_line_driver.html", + ], +] + +redirect_template = """ + + + + + + + +""" + + +def build_legacy_redirect(tvm_path): + def legacy_redirect(app, docname): # Sphinx expects two arguments + if app.builder.name == "html": + + src = Template(redirect_template) + + for frm, to in legacy_redirects: + frm = tvm_path.resolve() / "docs" / "_build" / "html" / frm + redirect = src.substitute({"to": to}) + os.makedirs(os.path.dirname(frm), exist_ok=True) + with open(frm, "w") as f: + f.write(redirect) + + return legacy_redirect diff --git a/docs/reference/api/python/relay/backend.rst b/docs/reference/api/python/relay/backend.rst index ffe8a9a8ce79..e717ee10ffab 100644 --- a/docs/reference/api/python/relay/backend.rst +++ b/docs/reference/api/python/relay/backend.rst @@ -23,7 +23,7 @@ tvm.relay.backend .. automodule:: tvm.relay.backend.interpreter :members: -.. automodule:: tvm.relay.backend.compile_engine +.. automodule:: tvm.relay.backend.te_compiler :members: .. automodule:: tvm.relay.backend.graph_executor_codegen diff --git a/gallery/how_to/work_with_microtvm/micro_autotune.py b/gallery/how_to/work_with_microtvm/micro_autotune.py index e7a1fa84a110..d3106712aa99 100644 --- a/gallery/how_to/work_with_microtvm/micro_autotune.py +++ b/gallery/how_to/work_with_microtvm/micro_autotune.py @@ -113,12 +113,9 @@ # choose other options by choosing from `PLATFORM` list. # -repo_root = pathlib.Path( - subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() -) module_loader = tvm.micro.AutoTvmModuleLoader( - template_project_dir=repo_root / "src" / "runtime" / "crt" / "host", + template_project_dir=pathlib.Path(tvm.micro.get_microtvm_template_projects("crt")), project_options={"verbose": False}, ) builder = tvm.autotvm.LocalBuilder( @@ -134,7 +131,7 @@ # Compiling for physical hardware # -------------------------------------------------------------------------- # module_loader = tvm.micro.AutoTvmModuleLoader( -# template_project_dir=repo_root / "apps" / "microtvm" / "zephyr" / "template_project", +# template_project_dir=pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")), # project_options={ # "zephyr_board": BOARD, # "west_cmd": "west", @@ -183,7 +180,7 @@ temp_dir = tvm.contrib.utils.tempdir() project = tvm.micro.generate_project( - str(repo_root / "src" / "runtime" / "crt" / "host"), + str(tvm.micro.get_microtvm_template_projects("crt")), lowered, temp_dir / "project", {"verbose": False}, @@ -192,7 +189,7 @@ # Compiling for physical hardware # -------------------------------------------------------------------------- # project = tvm.micro.generate_project( -# str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"), +# str(tvm.micro.get_microtvm_template_projects("zephyr")), # lowered, # temp_dir / "project", # { @@ -226,7 +223,7 @@ temp_dir = tvm.contrib.utils.tempdir() project = tvm.micro.generate_project( - str(repo_root / "src" / "runtime" / "crt" / "host"), + str(tvm.micro.get_microtvm_template_projects("crt")), lowered_tuned, temp_dir / "project", {"verbose": False}, @@ -235,7 +232,7 @@ # Compiling for physical hardware # -------------------------------------------------------------------------- # project = tvm.micro.generate_project( -# str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"), +# str(tvm.micro.get_microtvm_template_projects("zephyr")), # lowered_tuned, # temp_dir / "project", # { diff --git a/gallery/how_to/work_with_microtvm/micro_tflite.py b/gallery/how_to/work_with_microtvm/micro_tflite.py index cab105cb450f..35b08d87b9ee 100644 --- a/gallery/how_to/work_with_microtvm/micro_tflite.py +++ b/gallery/how_to/work_with_microtvm/micro_tflite.py @@ -269,10 +269,7 @@ import subprocess import pathlib -repo_root = pathlib.Path( - subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() -) -template_project_path = repo_root / "src" / "runtime" / "crt" / "host" +template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("crt")) project_options = {} # You can use options to provide platform-specific options through TVM. # Compiling for physical hardware (or an emulated board, like the mps_an521) @@ -280,7 +277,7 @@ # For physical hardware, you can try out the Zephyr platform by using a different template project # and options: # -# template_project_path = repo_root / "apps" / "microtvm" / "zephyr" / "template_project" +# template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) # project_options = {"project_type": "host_driven", zephyr_board": "nucleo_f746zg"}} # Create a temporary directory diff --git a/gallery/tutorial/autotvm_relay_x86.py b/gallery/tutorial/autotvm_relay_x86.py index 8b9c45c2a859..67b832cc226d 100644 --- a/gallery/tutorial/autotvm_relay_x86.py +++ b/gallery/tutorial/autotvm_relay_x86.py @@ -106,7 +106,7 @@ # TVMC has adopted NumPy's ``.npz`` format for both input and output data. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -278,6 +278,7 @@ from tvm.autotvm.tuner import XGBTuner from tvm import autotvm +################################################################################ # Set up some basic parameters for the runner. The runner takes compiled code # that is generated with a specific set of parameters and measures the # performance of it. ``number`` specifies the number of different @@ -303,6 +304,7 @@ enable_cpu_cache_flush=True, ) +################################################################################ # Create a simple structure for holding tuning options. We use an XGBoost # algorithim for guiding the search. For a production job, you will want to set # the number of trials to be larger than the value of 10 used here. For CPU we @@ -426,6 +428,7 @@ for rank in ranks[0:5]: print("class='%s' with probability=%f" % (labels[rank], scores[rank])) +################################################################################ # Verifying that the predictions are the same: # # .. code-block:: bash diff --git a/gallery/tutorial/tensor_expr_get_started.py b/gallery/tutorial/tensor_expr_get_started.py index fda332cb63ba..e4d947d1c488 100644 --- a/gallery/tutorial/tensor_expr_get_started.py +++ b/gallery/tutorial/tensor_expr_get_started.py @@ -133,7 +133,7 @@ ################################################################################ # Let's run the function, and compare the output to the same computation in -# numpy. The compiled TVM function is exposes a concise C API that can be invoked +# numpy. The compiled TVM function exposes a concise C API that can be invoked # from any language. We begin by creating a device, which is a device (CPU in this # example) that TVM can compile the schedule to. In this case the device is an # LLVM CPU target. We can then initialize the tensors in our device and @@ -258,8 +258,8 @@ def evaluate_addition(func, target, optimization, log): print(tvm.lower(s, [A, B, C], simple_mode=True)) ################################################################################ -# Comparing the Diferent Schedules -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Comparing the Different Schedules +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We can now compare the different schedules baseline = log[0][1] @@ -347,7 +347,7 @@ def evaluate_addition(func, target, optimization, log): fadd = tvm.build(s, [A, B, C], target=tgt_gpu, name="myadd") ################################################################################ - # The compiled TVM function is exposes a concise C API that can be invoked from + # The compiled TVM function exposes a concise C API that can be invoked from # any language. # # We provide a minimal array API in python to aid quick testing and prototyping. diff --git a/gallery/tutorial/tensor_ir_blitz_course.py b/gallery/tutorial/tensor_ir_blitz_course.py new file mode 100644 index 000000000000..e9a0801f34a8 --- /dev/null +++ b/gallery/tutorial/tensor_ir_blitz_course.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _tir_blitz: + +Blitz Course to TensorIR +======================== +**Author**: `Siyuan Feng `_ + +TensorIR is a domain specific language for deep learning programs serving two broad purposes: + +- An implementation for transforming and optimizing programs on various hardware backends. + +- An abstraction for automatic tensorized program optimization. + +""" + +import tvm +from tvm.ir.module import IRModule +from tvm.script import tir as T +import numpy as np + +################################################################################################ +# IRModule +# -------- +# An IRModule is the central data structure in TVM, which contains deep learning programs. +# It is the basic object of interest of IR transformation and model building. +# +# .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_life_of_irmodule.png +# :align: center +# :width: 85% +# +# This is the life cycle of an IRModule, which can be created from TVMScript. TensorIR schedule +# primitives and passes are two major ways to transform an IRModule. Also, a sequence of +# transformations on an IRModule is acceptable. Note that we can print an IRModule at **ANY** stage +# to TVMScript. After all transformations and optimizations are complete, we can build the IRModule +# to a runnable module to deploy on target devices. +# +# Based on the design of TensorIR and IRModule, we are able to create a new programming method: +# +# 1. Write a program by TVMScript in a python-AST based syntax. +# +# 2. Transform and optimize a program with python api. +# +# 3. Interactively inspect and try the performance with an imperative style transformation API. + + +################################################################################################ +# Create an IRModule +# ------------------ +# IRModule can be created by writing TVMScript, which is a round-trippable syntax for TVM IR. +# +# Different than creating a computational expression by Tensor Expression +# (:ref:`tutorial-tensor-expr-get-started`), TensorIR allow users to program through TVMScript, +# a language embedded in python AST. The new method makes it possible to write complex programs +# and further schedule and optimize it. +# +# Following is a simple example for vector addition. +# + + +@tvm.script.ir_module +class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. + vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 + + +ir_module = MyModule +print(type(ir_module)) +print(ir_module.script()) + +################################################################################################ +# Besides, we can also use tensor expression DSL to write simple operators, and convert them +# to an IRModule. +# + +from tvm import te + +A = te.placeholder((8,), dtype="float32", name="A") +B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B") +func = te.create_prim_func([A, B]) +ir_module_from_te = IRModule({"main": func}) +print(ir_module_from_te.script()) + + +################################################################################################ +# Build and Run an IRModule +# ------------------------- +# We can build the IRModule into a runnable module with specific target backends. +# + +mod = tvm.build(ir_module, target="llvm") # The module for CPU backends. +print(type(mod)) + +################################################################################################ +# Prepare the input array and output array, then run the module. +# + +a = tvm.nd.array(np.arange(8).astype("float32")) +b = tvm.nd.array(np.zeros((8,)).astype("float32")) +mod(a, b) +print(a) +print(b) + + +################################################################################################ +# Transform an IRModule +# --------------------- +# The IRModule is the central data structure for program optimization, which can be transformed +# by :code:`Schedule`. +# A schedule contains multiple primitive methods to interactively transform the program. +# Each primitive transforms the program in certain ways to bring additional performance optimizations. +# +# .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_tensor_ir_opt_flow.png +# :align: center +# :width: 100% +# +# The image above is a typical workflow for optimizing a tensor program. First, we need to create a +# schedule on the initial IRModule created from either TVMScript or Tensor Expression. Then, a +# sequence of schedule primitives will help to improve the performance. And at last, we can lower +# and build it into a runnable module. +# +# Here we just demostrate a very simple tranformation. First we create schedule on the input `ir_module`. + +sch = tvm.tir.Schedule(ir_module) +print(type(sch)) + +################################################################################################ +# Tile the loop into 3 loops and print the result. + +# Get block by its name +block_b = sch.get_block("B") +# Get loops surronding the block +(i,) = sch.get_loops(block_b) +# Tile the loop nesting. +i_0, i_1, i_2 = sch.split(i, factors=[2, 2, 2]) +print(sch.mod.script()) + + +################################################################################################ +# We can also reorder the loops. Now we move loop `i_2` to outside of `i_1`. +sch.reorder(i_0, i_2, i_1) +print(sch.mod.script()) + + +################################################################################################ +# Transform to a GPU program +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# If we want to deploy models on GPUs, thread binding is necessary. Fortunately, we can +# also use primitives and do incrementally transformation. +# + +sch.bind(i_0, "blockIdx.x") +sch.bind(i_2, "threadIdx.x") +print(sch.mod.script()) + + +################################################################################################ +# After binding the threads, now build the IRModule with :code:`cuda` backends. +ctx = tvm.cuda(0) +cuda_mod = tvm.build(sch.mod, target="cuda") +cuda_a = tvm.nd.array(np.arange(8).astype("float32"), ctx) +cuda_b = tvm.nd.array(np.zeros((8,)).astype("float32"), ctx) +cuda_mod(cuda_a, cuda_b) +print(cuda_a) +print(cuda_b) diff --git a/gallery/tutorial/tvmc_command_line_driver.py b/gallery/tutorial/tvmc_command_line_driver.py index 7a0b97895e4f..facb978cea67 100644 --- a/gallery/tutorial/tvmc_command_line_driver.py +++ b/gallery/tutorial/tvmc_command_line_driver.py @@ -174,10 +174,10 @@ # data types. For this reason, most models require some pre and # post-processing, to ensure the input is valid and to interpret the output. # TVMC has adopted NumPy's ``.npz`` format for both input and output data. This -# is a well-supported NumPy format to serialize multiple arrays into a file +# is a well-supported NumPy format to serialize multiple arrays into a file. # # As input for this tutorial, we will use the image of a cat, but you can feel -# free to substitute image for any of your choosing. +# free to substitute this image for any of your choosing. # # .. image:: https://s3.amazonaws.com/model-server/inputs/kitten.jpg # :height: 224px @@ -197,8 +197,8 @@ # requirement for the script. # # .. code-block:: python -# :caption: preprocess.py -# :name: preprocess.py +# :caption: preprocess.py +# :name: preprocess.py # # #!python ./preprocess.py # from tvm.contrib.download import download_testdata diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 6b350e25e167..f55a0651a870 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -249,7 +249,7 @@ IntSet UnionLowerBound(const Array& sets); Array UnionRegionLowerBound(const Array>& nd_int_sets); /*! - * \brief Create an union set of all sets + * \brief Create an intersected set of all sets * \param sets The sets to be intersected * \return the set after intersected */ diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 715c96eb6ea5..f6c15f9590df 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -489,7 +489,7 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization." + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " << "If the key is defined check that its type matches the declared type."; throw AttrError(os.str()); } @@ -806,7 +806,7 @@ class AttrsNode : public BaseAttrsNode { ICHECK_EQ(args.size() % 2, 0); const int kLinearSearchBound = 16; int hit_count = 0; - // applies two stratgies to lookup + // applies two strategies to lookup if (args.size() < kLinearSearchBound) { // linear search. auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 19358552df10..b809843f4157 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode { } Array Build(const Array& build_inputs) final { + ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); } diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 7ba3c207e349..60c6898f000b 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode { // `f_size` is not visited } - static constexpr const char* _type_key = "meta_schedule.PyDatabase"; - TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); - - Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); } + Workload CommitWorkload(const IRModule& mod) final { + ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!"; + return f_commit_workload(mod); + } - void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); } + void CommitTuningRecord(const TuningRecord& record) final { + ICHECK(f_commit_tuning_record != nullptr) + << "PyDatabase's CommitTuningRecord method not implemented!"; + f_commit_tuning_record(record); + } Array GetTopK(const Workload& workload, int top_k) final { + ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!"; return f_get_top_k(workload, top_k); } - int64_t Size() final { return f_size(); } + int64_t Size() final { + ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!"; + return f_size(); + } + + static constexpr const char* _type_key = "meta_schedule.PyDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); }; /*! diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h new file mode 100644 index 000000000000..c6925eed91c4 --- /dev/null +++ b/include/tvm/meta_schedule/integration.h @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_INTEGRATION_H_ +#define TVM_META_SCHEDULE_INTEGRATION_H_ + +#include +#include + +#include + +namespace tvm { +namespace meta_schedule { + +/**************** ExtractedTask ****************/ + +/*! + * \brief A tuning task extracted from the high-level IR + */ +class ExtractedTaskNode : public runtime::Object { + public: + /*! \brief The name of the task extracted */ + String task_name; + /*! \brief The high-level IR */ + IRModule mod; + /*! \brief A list of low-level IRs that the high-level IR could potentially dispatch to */ + Array dispatched; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("task_name", &task_name); + v->Visit("mod", &mod); + v->Visit("dispatched", &dispatched); + } + + static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); +}; + +/*! + * \brief Managed reference to ExtractedTaskNode + * \sa ExtractedTaskNode + */ +class ExtractedTask : public runtime::ObjectRef { + public: + /*! + * \brief Constructor. The name of the task extracted + * \brief The high-level IR + * \brief A list of low-level IRs that the high-level IR could potentially dispatch to + */ + explicit ExtractedTask(String task_name, IRModule mod, Array dispatched); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExtractedTask, runtime::ObjectRef, ExtractedTaskNode); +}; + +/**************** MetaScheduleContext ****************/ + +/*! + * \brief A context manager interface for the integration + */ +class MetaScheduleContextNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~MetaScheduleContextNode() = default; + /*! + * \brief The entry point of the integration + * \param task_name The name of the task + * \param mod The high-level IR + * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to. + * NullOpt means the dispatch needs to be done in the context. + * \return There are different types of the output + * 1) NullOpt if there is no feedback hint + * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc + * 3) relay::Function if `mod` should be dispatched to BYOC workflow + * 4) IRModule for unified dispatch + */ + virtual Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) = 0; + + static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; + TVM_DECLARE_BASE_OBJECT_INFO(MetaScheduleContextNode, runtime::Object); +}; + +/*! + * \brief Managed reference to MetaScheduleContextNode + * \sa MetaScheduleContextNode + */ +class MetaScheduleContext : public runtime::ObjectRef { + friend class MetaScheduleContextInternal; + friend class With; + + public: + /*! \brief Default destructor */ + virtual ~MetaScheduleContext() = default; + /*! + * \brief The context manager in the current scope + * \return The MetaScheduleContext in the current scope. NullOpt if it's currently not under any + * MetaScheduleContext. + */ + static Optional Current(); + /*! + * \brief The entry point of the integration workflow. The compilation process of the high-level + * IR should call this method for task extraction and for feedback hints + * \param task_name The name of the task + * \param mod The high-level IR + * \param dispatched A list of low-level IRs that the high-level IR could potentially dispatch to + * \return There are different types of the output + * 1) NullOpt if there is no feedback hint + * 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc + * 3) relay::Function if `mod` should be dispatched to BYOC workflow + * 4) IRModule for unified dispatch + */ + static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, + Optional> dispatched); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetaScheduleContext, runtime::ObjectRef, + MetaScheduleContextNode); + + protected: + /*! \brief Default constructor */ + MetaScheduleContext() = default; + /*! \brief Entering the scope of the context manager */ + void EnterWithScope(); + /*! \brief Exiting the scope of the context manager */ + void ExitWithScope(); +}; + +/**************** TaskExtraction ****************/ + +/*! + * \brief An integration context for task extraction + */ +class TaskExtractionNode : public MetaScheduleContextNode { + public: + /*! \brief The extracted tasks */ + Array tasks{nullptr}; + + void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); } + + // Inherited from base class + Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) final; + + static constexpr const char* _type_key = "meta_schedule.TaskExtraction"; + TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode); +}; + +/*! + * \brief Managed reference to TaskExtractionNode + * \sa TaskExtractionNode + */ +class TaskExtraction : public MetaScheduleContext { + public: + /*! \brief The path to a cache file storing extracted tasks */ + TaskExtraction(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext, + TaskExtractionNode); +}; + +/**************** ApplyHistoryBest ****************/ + +/*! + * \brief An integration context that allows application of historically best records from a + * database + */ +class ApplyHistoryBestNode : public MetaScheduleContextNode { + public: + /*! \brief The database to be queried from */ + Database database{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("database", &database); // + } + + // Inherited from base class + Optional Query(runtime::String task_name, IRModule mod, + Optional> dispatched) final; + + static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; + TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, MetaScheduleContextNode); +}; + +/*! + * \brief Managed reference to ApplyHistoryBestNode + * \sa ApplyHistoryBestNode + */ +class ApplyHistoryBest : public MetaScheduleContext { + public: + /*! + * \brief Constructor + * \param database The database to be queried from + */ + explicit ApplyHistoryBest(Database database); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ApplyHistoryBest, MetaScheduleContext, + ApplyHistoryBestNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_INTEGRATION_H_ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c1451ae977d4..b154195f43a6 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode { // `f_run` is not visited } - Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + Array Run(Array runner_inputs) final { + ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; + return f_run(runner_inputs); + } static constexpr const char* _type_key = "meta_schedule.PyRunner"; TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 941dae4336e1..0f3e9298d11a 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -187,20 +187,30 @@ class PySearchStrategyNode : public SearchStrategyNode { } void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySearchStrategy's InitializeWithTuneContext method not implemented!"; this->f_initialize_with_tune_context(context); } void PreTuning(const Array& design_spaces) final { + ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!"; this->f_pre_tuning(design_spaces); } - void PostTuning() final { this->f_post_tuning(); } + void PostTuning() final { + ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!"; + this->f_post_tuning(); + } Optional> GenerateMeasureCandidates() final { + ICHECK(f_generate_measure_candidates != nullptr) + << "PySearchStrategy's GenerateMeasureCandidates method not implemented!"; return this->f_generate_measure_candidates(); } void NotifyRunnerResults(const Array& results) final { + ICHECK(f_notify_runner_results != nullptr) + << "PySearchStrategy's NotifyRunnerResults method not implemented!"; this->f_notify_runner_results(results); } diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 3dc181e05d8a..eadf5e91506c 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { } void InitializeWithTuneContext(const TuneContext& tune_context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PySpaceGenerator's InitializeWithTuneContext !"; f_initialize_with_tune_context(tune_context); } Array GenerateDesignSpace(const IRModule& mod) final { + ICHECK(f_generate_design_space != nullptr) + << "PySpaceGenerator's GenerateDesignSpace method not implemented!"; return f_generate_design_space(mod); } diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index a2db24e31a87..64ba3ddeafb1 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -87,6 +87,12 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief Auto-tuning. */ virtual void Tune(); + /*! + * \brief Initialize modules of the given task. + * \param task_id The task id to be initialized. + */ + virtual void InitializeTask(int task_id); + /*! * \brief Set specific task to be stopped. * \param task_id The task id to be stopped. @@ -116,12 +122,17 @@ class TaskSchedulerNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object); }; +class TaskScheduler; + /*! \brief The task scheduler with customized methods on the python-side. */ class PyTaskSchedulerNode : public TaskSchedulerNode { public: /*! \brief The function type of `Tune` method. */ using FTune = runtime::TypedPackedFunc; + /*! \brief The function type of `InitializeTask` method. */ + using FInitializeTask = runtime::TypedPackedFunc; + /*! * \brief The function type of `SetTaskStopped` method. * \param task_id The task id to be stopped. @@ -149,6 +160,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { /*! \brief The packed function to the `Tune` funcion. */ FTune f_tune; + /*! \brief The packed function to the `InitializeTask` funcion. */ + FInitializeTask f_initialize_task; /*! \brief The packed function to the `SetTaskStopped` function. */ FSetTaskStopped f_set_task_stopped; /*! \brief The packed function to the `IsTaskRunning` function. */ @@ -160,29 +173,55 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { void VisitAttrs(tvm::AttrVisitor* v) { // `f_tune` is not visited + // `f_initialize_task` is not visited // `f_set_task_stopped` is not visited // `f_is_task_running` is not visited // `f_join_running_task` is not visited // `f_next_task_id` is not visited } - void Tune() final { // - f_tune(); + void Tune() final { + if (f_tune == nullptr) { + TaskSchedulerNode::Tune(); + } else { + f_tune(); + } + } + + void InitializeTask(int task_id) final { + if (f_initialize_task == nullptr) { + TaskSchedulerNode::InitializeTask(task_id); + } else { + f_initialize_task(task_id); + } } - void SetTaskStopped(int task_id) final { // - f_set_task_stopped(task_id); + void SetTaskStopped(int task_id) final { + if (f_set_task_stopped == nullptr) { + TaskSchedulerNode::SetTaskStopped(task_id); + } else { + f_set_task_stopped(task_id); + } } - bool IsTaskRunning(int task_id) final { // - return f_is_task_running(task_id); + bool IsTaskRunning(int task_id) final { + if (f_is_task_running == nullptr) { + return TaskSchedulerNode::IsTaskRunning(task_id); + } else { + return f_is_task_running(task_id); + } } - void JoinRunningTask(int task_id) final { // - f_join_running_task(task_id); + void JoinRunningTask(int task_id) final { + if (f_join_running_task == nullptr) { + return TaskSchedulerNode::JoinRunningTask(task_id); + } else { + return f_join_running_task(task_id); + } } - int NextTaskId() final { // + int NextTaskId() final { + ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!"; return f_next_task_id(); } @@ -203,10 +242,17 @@ class TaskScheduler : public runtime::ObjectRef { * \param runner The runner of the scheduler. * \param database The database of the scheduler. */ - TVM_DLL static TaskScheduler RoundRobin(Array tasks, Builder builder, Runner runner, - Database database); + TVM_DLL static TaskScheduler RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database); // TVM_DLL static TaskScheduler PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // PyTaskSchedulerNode::FTune f_tune, // + PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 85ac3f36ff60..f88ca8ef6380 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -116,17 +116,6 @@ struct CompilerAttrs : public tvm::AttrsNode { } }; -/*! - * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. - */ -struct TIRCallAttrs : public tvm::AttrsNode { - /*! \brief The metadata attached to the call node. */ - Map metadata; - - TVM_DECLARE_ATTRS(TIRCallAttrs, "relay.attrs.TIRCallAttrs") { - TVM_ATTR_FIELD(metadata).describe("Metadata attached to the TIR function call."); - } -}; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h new file mode 100644 index 000000000000..2b02c6a5edac --- /dev/null +++ b/include/tvm/relay/attrs/call.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/call.h + * \brief Attribute for call_lowered operator. + */ +#ifndef TVM_RELAY_ATTRS_CALL_H_ +#define TVM_RELAY_ATTRS_CALL_H_ + +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Metadata for calls to TIR functions, useful for program analysis crossing Relay and TIR. + */ +struct CallLoweredAttrs : public tvm::AttrsNode { + /*! \brief The metadata attached to the call node. */ + Map metadata; + + TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { + TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call."); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_CALL_H_ diff --git a/include/tvm/relay/executor.h b/include/tvm/relay/executor.h new file mode 100644 index 000000000000..4f779e1dc0a4 --- /dev/null +++ b/include/tvm/relay/executor.h @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/executor.h + * \brief Object representation of Executor configuration and registry + */ +#ifndef TVM_RELAY_EXECUTOR_H_ +#define TVM_RELAY_EXECUTOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +template +class AttrRegistry; + +namespace relay { + +/*! + * \brief Executor information. + * + * This data structure stores the meta-data + * about executors which can be used to pass around information. + * + * \sa Executor + */ +class ExecutorNode : public Object { + public: + /*! \brief name of the Executor */ + String name; + /* \brief Additional attributes storing meta-data about the Executor. */ + DictAttrs attrs; + + /*! + * \brief Get an attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TObjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const Executor& executor) { + * auto value = executor->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("attrs", &attrs); + } + + bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { + return name == other->name && equal.DefEqual(attrs, other->attrs); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(attrs); + } + + static constexpr const char* _type_key = "Executor"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); +}; + +/*! + * \brief Managed reference class to ExecutorNode. + * \sa ExecutorNode + */ +class Executor : public ObjectRef { + public: + /*! + * \brief Create a new Executor object using the registry + * \throws Error if name is not registered + * \param name The name of the executor. + * \param attrs Attributes for the executor. + * \return the new Executor object. + */ + TVM_DLL static Executor Create(String name, Map attrs); + + /*! + * \brief List all registered Executors + * \return the list of Executors + */ + TVM_DLL static Array ListExecutors(); + + /*! + * \brief List all options for a specific Executor + * \param name The name of the Executor + * \return Map of option name to type + */ + TVM_DLL static Map ListExecutorOptions(const String& name); + + /*! \brief specify container node */ + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); + + private: + /*! + * \brief Private Constructor + * \param name The executor name + * \param attrs Attributes to apply to this Executor node + */ + TVM_DLL Executor(String name, DictAttrs attrs) { + auto n = make_object(); + n->name = std::move(name); + n->attrs = std::move(attrs); + data_ = std::move(n); + } +}; + +/*! + * \brief Helper structure to register Executors + * \sa TVM_REGISTER_EXECUTOR + */ +class ExecutorRegEntry { + public: + /*! \brief Set name of the Executor to be the same as registry if it is empty */ + inline ExecutorRegEntry& set_name(); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \tparam ValueType The value type to be registered + */ + template + inline ExecutorRegEntry& add_attr_option(const String& key); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \param default_value The default value of the key + * \tparam ValueType The value type to be registered + */ + template + inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); + + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); + + private: + /*! \brief Internal storage of value types */ + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + std::unordered_map key2vtype_; + /*! \brief A hash table that stores the default value of each attr */ + std::unordered_map key2default_; + + /*! \brief Index used for internal lookup of attribute registry */ + uint32_t index_; + + // the name + std::string name; + + /*! \brief Return the index stored in attr registry */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief Return the name stored in attr registry */ + String AttrRegistryName() const { return name; } + + /*! \brief private constructor */ + explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} + + // friend class + template + friend class AttrRegistryMapContainerMap; + template + friend class tvm::AttrRegistry; + friend class Executor; +}; + +inline ExecutorRegEntry& ExecutorRegEntry::set_name() { + if (name.empty()) { + name = name; + } + return *this; +} + +template +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { + ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key + << "' has been set once"; + + using ValueNodeType = typename ValueType::ContainerType; + // NOTE: we could further update the function later. + uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + return *this; +} + +template +inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, + ObjectRef default_value) { + add_attr_option(key); + key2default_[key] = default_value; + return *this; +} + +// internal macros to make executor entries +#define TVM_EXECUTOR_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor + +/*! + * \def TVM_REGISTER_EXECUTOR + * \brief Register a new executor, or set attribute of the corresponding executor. + * + * \param ExecutorName The name of registry + * + * \code + * + * TVM_REGISTER_EXECUTOR("aot") + * .add_attr_option("my_option"); + * .add_attr_option("my_option_default", String("default")); + * + * \endcode + */ +#define TVM_REGISTER_EXECUTOR(ExecutorName) \ + TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName).set_name() +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_EXECUTOR_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index daad8514f9ff..aa341949b3d6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -295,6 +295,8 @@ class CallNode : public ExprNode { static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); + template + friend class runtime::ObjAllocatorBase; friend class Call; }; @@ -333,6 +335,11 @@ class Call : public Expr { class Let; /*! \brief A binding of a sub-network. */ class LetNode : public ExprNode { + protected: + // LetNode uses own deleter to indirectly call non-recursive destructor + Object::FDeleter saved_deleter_; + static void Deleter_(Object* ptr); + public: /*! \brief The variable we bind to */ Var var; @@ -364,10 +371,18 @@ class LetNode : public ExprNode { static constexpr const char* _type_key = "relay.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode); + template + friend class runtime::ObjAllocatorBase; + friend class Let; }; class Let : public Expr { public: + /*! + * \brief The destructor + */ + ~Let(); + /*! * \brief The constructor * \param var The variable that is bound to. @@ -639,5 +654,40 @@ class TempExpr : public Expr { }; } // namespace relay + +namespace runtime { + +template <> +template <> +inline ObjectPtr +ObjAllocatorBase::make_object() { + using Derived = SimpleObjAllocator; + using T = relay::LetNode; + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this)); + ptr->type_index_ = T::RuntimeTypeIndex(); + ptr->saved_deleter_ = Handler::Deleter(); + ptr->deleter_ = relay::LetNode::Deleter_; + return ObjectPtr(ptr); +} + +template <> +template <> +inline ObjectPtr +ObjAllocatorBase::make_object() { + using Derived = SimpleObjAllocator; + using T = relay::CallNode; + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this)); + ptr->type_index_ = T::RuntimeTypeIndex(); + ptr->saved_deleter_ = Handler::Deleter(); + ptr->deleter_ = relay::CallNode::Deleter_; + return ObjectPtr(ptr); +} + +} // namespace runtime + } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/relay/runtime.h b/include/tvm/relay/runtime.h new file mode 100644 index 000000000000..cc2ea4193ff2 --- /dev/null +++ b/include/tvm/relay/runtime.h @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/runtime.h + * \brief Object representation of Runtime configuration and registry + */ +#ifndef TVM_RELAY_RUNTIME_H_ +#define TVM_RELAY_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +template +class AttrRegistry; + +namespace relay { + +/*! + * \brief Runtime information. + * + * This data structure stores the meta-data + * about Runtimes which can be used to pass around information. + * + * \sa Runtime + */ +class RuntimeNode : public Object { + public: + /*! \brief name of the Runtime */ + String name; + /* \brief Additional attributes storing meta-data about the Runtime. */ + DictAttrs attrs; + + /*! + * \brief Get an attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TObjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const Runtime& runtime) { + * auto value = runtime->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("attrs", &attrs); + } + + bool SEqualReduce(const RuntimeNode* other, SEqualReducer equal) const { + return name == other->name && equal.DefEqual(attrs, other->attrs); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(attrs); + } + + static constexpr const char* _type_key = "Runtime"; + TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object); +}; + +/*! + * \brief Managed reference class to RuntimeNode. + * \sa RuntimeNode + */ +class Runtime : public ObjectRef { + public: + /*! + * \brief Create a new Runtime object using the registry + * \throws Error if name is not registered + * \param name The name of the Runtime. + * \param attrs Attributes for the Runtime. + * \return the new Runtime object. + */ + TVM_DLL static Runtime Create(String name, Map attrs); + + /*! + * \brief List all registered Runtimes + * \return the list of Runtimes + */ + TVM_DLL static Array ListRuntimes(); + + /*! + * \brief List all options for a specific Runtime + * \param name The name of the Runtime + * \return Map of option name to type + */ + TVM_DLL static Map ListRuntimeOptions(const String& name); + + /*! \brief specify container node */ + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Runtime, ObjectRef, RuntimeNode); + + private: + /*! + * \brief Private Constructor + * \param name The Runtime name + * \param attrs Attributes to apply to this Runtime node + */ + TVM_DLL Runtime(String name, DictAttrs attrs) { + auto n = make_object(); + n->name = std::move(name); + n->attrs = std::move(attrs); + data_ = std::move(n); + } +}; + +/*! + * \brief Helper structure to register Runtimes + * \sa TVM_REGISTER_Runtime + */ +class RuntimeRegEntry { + public: + /*! \brief Set name of the Runtime to be the same as registry if it is empty */ + inline RuntimeRegEntry& set_name(); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \tparam ValueType The value type to be registered + */ + template + inline RuntimeRegEntry& add_attr_option(const String& key); + + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \param default_value The default value of the key + * \tparam ValueType The value type to be registered + */ + template + inline RuntimeRegEntry& add_attr_option(const String& key, ObjectRef default_value); + + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static RuntimeRegEntry& RegisterOrGet(const String& name); + + private: + /*! \brief Internal storage of value types */ + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + std::unordered_map key2vtype_; + /*! \brief A hash table that stores the default value of each attr */ + std::unordered_map key2default_; + + /*! \brief Index used for internal lookup of attribute registry */ + uint32_t index_; + + // the name + std::string name; + + /*! \brief Return the index stored in attr registry */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! \brief Return the name stored in attr registry */ + String AttrRegistryName() const { return name; } + + /*! \brief private constructor */ + explicit RuntimeRegEntry(uint32_t reg_index) : index_(reg_index) {} + + // friend class + template + friend class AttrRegistryMapContainerMap; + template + friend class tvm::AttrRegistry; + friend class Runtime; +}; + +inline RuntimeRegEntry& RuntimeRegEntry::set_name() { + if (name.empty()) { + name = name; + } + return *this; +} + +template +inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key) { + ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key + << "' has been set once"; + + using ValueNodeType = typename ValueType::ContainerType; + // NOTE: we could further update the function later. + uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + return *this; +} + +template +inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key, + ObjectRef default_value) { + add_attr_option(key); + key2default_[key] = default_value; + return *this; +} + +// internal macros to make Runtime entries +#define TVM_RUNTIME_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::RuntimeRegEntry& __make_##Runtime + +/*! + * \def TVM_REGISTER_RUNTIME + * \brief Register a new Runtime, or set attribute of the corresponding Runtime. + * + * \param RuntimeName The name of registry + * + * \code + * + * TVM_REGISTER_RUNTIME("c") + * .add_attr_option("my_option"); + * .add_attr_option("my_option_default", String("default")); + * + * \endcode + */ +#define TVM_REGISTER_RUNTIME(RuntimeName) \ + TVM_STR_CONCAT(TVM_RUNTIME_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::relay::RuntimeRegEntry::RegisterOrGet(RuntimeName).set_name() +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_RUNTIME_H_ diff --git a/include/tvm/runtime/crt/graph_executor.h b/include/tvm/runtime/crt/graph_executor.h index eb68ff56d230..1353d8e06e6b 100644 --- a/include/tvm/runtime/crt/graph_executor.h +++ b/include/tvm/runtime/crt/graph_executor.h @@ -36,7 +36,7 @@ struct TVMModule; /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { - char func_name[120]; + char func_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; uint32_t num_inputs; uint32_t num_outputs; uint32_t flatten_data; diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 7b9a68063f16..366f4f1deed1 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -198,13 +198,20 @@ class ReportNode : public Object { */ String AsCSV() const; /*! \brief Create a human readable table of profiling metrics. - * \param aggregate Whether or not to join multiple calls to the same op into a single line. - * \param sort Whether or not to sort call frames by descending duration. If - * false and if `aggregate` is false, frames will be sorted by order of - * appearance in the program. Order is undefined if `sort` is false and - * `aggregate` is true. + * + * \param aggregate Whether or not to join multiple calls to the + * same op into a single line. + * + * \param sort Whether or not to sort call frames by descending + * duration. If false and if `aggregate` is false, frames will + * be sorted by order of appearance in the program. Order is + * undefined if `sort` is false and `aggregate` is true. + * + * \param compute_col_sums Whether or not to include sum totals for + * the Count, Duation, and Percent columns. + * */ - String AsTable(bool sort = true, bool aggregate = true) const; + String AsTable(bool sort = true, bool aggregate = true, bool compute_col_sums = true) const; /*! \brief Convert this report to JSON. * * Output JSON will be of this format: diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index 2cdd180730ec..6e564fd62380 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -131,6 +131,13 @@ class Executable : public ModuleNode { */ std::string GetBytecode() const; + /*! + * \brief Returns a description of all the contants in the executable in human-readable + * format. Not intended to be machine readable, but rather to help with debugging and + * diffing generated code. + */ + std::string GetConstants() const; + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 039b1894d7c4..ece73fcfda34 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -84,11 +84,11 @@ struct VMFunction { /*! \brief The size of the frame for this function */ Index register_file_size; /*! \brief The device type of each parameter for this function. */ - std::vector params_device_type; + std::vector params_device_type; VMFunction(const std::string& name, std::vector params, const std::vector& instructions, Index register_file_size, - const std::vector params_device_type = {}) + const std::vector params_device_type = {}) : name(name), params(params), instructions(instructions), diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h new file mode 100644 index 000000000000..facb74d6278e --- /dev/null +++ b/include/tvm/target/compilation_config.h @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/compilation_config.h + * \brief A helper class to collect all the targets in canonical form necessary for compilation. + * CAUTION: Preliminary, currently only used to support device planning, very likely to change. + */ + +#ifndef TVM_TARGET_COMPILATION_CONFIG_H_ +#define TVM_TARGET_COMPILATION_CONFIG_H_ + +#include + +namespace tvm { + +/*! + * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to + * compile a Relay module. All centralizes any setup and validation logic needed to transition + * from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly + * (eg a a list of \p Targets) to the configuration. + * + * CAUTION: This is subject to change as we rework compilation options in general. See + * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0028-command-line-registry-composition.md. + * So far this class is only focussed on carrying just the configuration needed by PlanDevices, + * and removing target-munging code duplication and inconsistencies between the three major build + * flows for the VM (relay/backend/vm/compile.cc), Graph/AOT (relay/backend/build_module.cc) and + * Interpreter (relay/backend/interpreter.cc). Over time we expect more global compiler + * configuration (eg for executor and runtime config, for system memory pool configuration, etc) + * to migrate into this class, and instances thereof to be attached to \p IRModules using a + * well-known attribute. + */ +class CompilationConfigNode : public Object { + public: + /*! + * \brief The legacy targets map, mapping device type to \p Targets. Does not include any + * entry for the host target. Intended to give a unique \p Target for every \p DLDeviceType, + * though we want to get rid of that limitation. + * + * CAUTION: Since keys are \p Integers they are compared by object equality not integer + * value. + * + * TODO(mbs): Remove once codegen updated for new target conventions. + */ + TargetMap legacy_target_map; + + /*! + * \brief The host target. Used for 'scalar' data and code (such as shapes and shape + * functions) and residual Relay expressions and data (such as conditionals and ADTs). + */ + Target host_target; + + /*! + * \brief Vector of all available targets for primitive operators. May contain a \p Target + * for the same device type as for the \p host_target, however the \p host_target should + * be preferred for all host computations and data. + */ + Array primitive_targets; + + /*! + * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular + * device. + */ + SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + + /*! \brief SEScope for the host. */ + SEScope host_se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If defined then compile and/or run in 'homogenous execution mode'. In this mode all + * primitives are compiled for this target only. + * + * This is to support legacy passes which have not been adapted to hetrogeneous execution and + * rely on an implicit global \p Target to be in scope. + * + * TODO(mbs): Remove once all passes are 'hetrogeneous aware'. + */ + Target optional_homogeneous_target; + + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Returns a \p SEScope agreeing with \p se_scope on all its constrained fields, however: + * - If the target is null then it is filled in from the known available primitive targets by + * matching on device type. Fails if no such target is known. + * - The returned object is unique for the field values w.r.t. all other \p SEScopes returned + * by this method. + * + * We call the result the 'canonical' \p SEScope. Two canonical \p SEScopes are structurally + * equal if and only if they are pointer equal. + */ + SEScope CanonicalSEScope(const SEScope& se_scope) const; + + static constexpr const char* _type_key = "CompilationConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object) + + private: + /*! + * \brief Establishes the default \p SEScope for primitives and the \p SEScope for the host + * given: + * - the vector of available primitive \p Targets. + * - any host \p Target. + * - any "relay.fallback_device_type" attribute on \p pass_ctx. + * - whether the LLVM backend is available. + * If necessary, creates new default \p Targets to match the required devices. + * + * NOTE: The implementation is a bit convoluted since it tries to maintain backwards + * compatibility with legacy methods for conveying \p Targets. + * + * CAUTION: Recreated the primitive_targets so that they all have the given/constructed + * host_target as their host (cf CheckAndUpdateHostConsistency). + */ + void EstablishDefaultSEScopes(const transform::PassContext& pass_ctx); + + /*! + * \brief Returns a freshly constructed \p Target to represent \p device_type. + */ + static Target MakeDefaultTarget(DLDeviceType device_type); + + /*! + * \brief Return the \p Target to use for \p device_type. Fail if no such target exists. + */ + Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const; + + /*! + * \brief A cache of constructed SEScopes. + */ + mutable SEScopeCache se_scope_cache_; + + friend class CompilationConfig; +}; + +/*! + * \brief Managed reference class to \p CompilationConfig + * + * \sa CompilationConfig + */ +class CompilationConfig : public ObjectRef { + public: + /*! + * \brief Constructs the compilation config given the available \p Targets in the + * \p legacy_target_map_arg and an optional \p optional_host_target_arg. May use + * 'relay.fallback_device_type' and the availability of the LLVM compilation module + * to decide on appropriate default devices. + */ + TVM_DLL CompilationConfig(const transform::PassContext& pass_ctx, TargetMap legacy_target_map_arg, + Target optional_host_target_arg); + + TVM_DEFINE_OBJECT_REF_METHODS(CompilationConfig, ObjectRef, CompilationConfigNode); +}; + +} // namespace tvm + +#endif // TVM_TARGET_COMPILATION_CONFIG_H_ diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h new file mode 100644 index 000000000000..981a0b85ab13 --- /dev/null +++ b/include/tvm/target/se_scope.h @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.h + * \brief A compile time representation for a Storage or Execution Scope. + */ + +#ifndef TVM_TARGET_SE_SCOPE_H_ +#define TVM_TARGET_SE_SCOPE_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { + +/*! + * Abstract label for an area of memory. + * + * Currently uninterpreted and arbitrary. Likely to be replaced by a structured representation + * of a memory pool in the future. Please try to use this alias instead of String to aid future + * code migration. + */ +using MemoryScope = String; + +/*! + * \brief Describes at compile time where data is to be stored down to the device and memory + * scope level, or where execution is to take place, down to the device level. It is a quadruple of: + * - A \p device_type (\p DLDeviceType). May be kInvalidDeviceType if unconstrained. + * - A \p virtual_device_id (\p int). This allows us to distinguish distinct devices + * with the same \p Target, for example in a multi-GPU system. May be -1 if unconstrained. + * See "Virtual Devices" below. + * - A \p target (\p Target) describing how to compile code for the intended device. May be null + * if unconstrained. + * - A \p memory_scope (\p MemoryScope, which is currently just \p String) describing which memory + * area is to be used to hold data. May be "" if unconstrained. See "Memory Scopes and Devices" + * below. + * + * Some or all of these fields may be unconstrained, signaling that device planning is free to + * choose a value consistent with the whole program. However if a \p target is given then the \p + * device_type must equal \p target->kind->device_type. + * + * Note that currently we assume if a function returns its result on a particular device + * then the function body is also executed on that device. See the overview comment in + * src/relay/transforms/device_planner.cc for more details. + * + * By 'data' we include both tensors and additional supporting datastructures such as shapes, + * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside + * on a 'CPU'-like device with good support for scalars. + * + * By 'execution' we include both (fused) primitive operators, and all the Relay expressions + * surrounding them which coordinates data and control flow. Again, typically non-primitive + * operators must be executed on a 'CPU'-like device with good support for control flow. + * + * Since TVM targets such a wide range of systems it is not possible for \p SEScope to impose + * much semantics on these fields, particularly for \p virtual_device_id and \p memory_scope. + * Instead we assume downstream passes and codegen will interpret an validate these fields + * appropriately. + * + * Targets vs Devices + * ------------------ + * Generally \p Targets (a compile-time only datastructue) describe compiler options for a specific + * microarchitecture and toolchain, while \p Devices (a runtime datastructure also available at + * compile time) describe a physical device on the target system. Obviously the target must agree + * with the device's microarchitecture, but we otherwise don't impose any constraints between them: + * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf + * out of a particular primitive. + * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. + * + * Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that + * assumption. + * + * Virtual vs Physical Devices + * --------------------------- + * The \p virtual_device_id may be used by downstream passes or the runtime to help decide which + * \p device_id to use for a particular physical runtime \p Device. For example: + * - Some runtimes may support passing in an array of actual `device` specifications, and the + * \p virtual_device_id can be used at runtime as an index into that array. + * - Some runtimes may support dynamically allocating computations to physical devices. On these + * systems a large space of \p virtual_device_ids could be used at compile time, even though + * at runtime only a few physical devices will be present. + * + * The \p virtual_device_id may also be left unconstrained if not needed. + * + * Memory Scopes and Devices + * ------------------------- + * Multi-device systems can have complex memory hierarchies. For example + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCPU, 1, "llvm", "global") + * \endcode + * could denote: + * - The same memory area accessible from two separate CPUs without any CPU affinity; + * - Distinct memory areas in a NUMA architecture for which cross-device access is handled + * by the memory system; + * - Outright distinct memory areas, where one device cannot directly address the memory of + * another. + * + * Similarly: + * \code + * (kDLCPU, 0, "llvm", "global") + * \endcode + * and + * \code + * (kDLCUDA, 0, "cuda", "host") + * \endcode + * could denote the same memory area, but with very different access costs. + * + * Furthermore, not all memory scopes are accessible to all devices, and it is possible for + * a memory scope to only be accessible to a device when code is compiled with particular + * \p Target options. + * + * \p SEScopes themselves have no system-level understanding. Currently device planning will + * simply insert "device_copy" operators wherever \p SEScopes are not exactly pointwise equal. + * We may revisit this in the future as the work on memory pools matures. + * + * Joining and Defaulting + * ---------------------- + * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees + * with both join arguments. Eg: + * \code + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global)) + * => (kDLCPU, 3, "llvm", "global") + * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "local)) + * => null (no join possible) + * \endcode + * + * Related to 'join' is 'default', which only takes constrained fields from the rhs when the + * lhs is unconstrained: + * \code + * Default(kDLCPU, -1, "llvm", "local"), (kDLCPU, 3, null, "global")) + * => (kDLCPU, 3, "llvm", "local") + * \endcode + * + * These operations are needed during device planning. + * + */ +class SEScopeNode : public AttrsNode { + public: + /*! + * \brief The \p DLDeviceType (represtented as an int) of the virtual device. If \p target is + * known then this will be equal to \p target->kind->device_type. If \p target is null then the + * target is to be determined later. + * + * This is needed to support the legacy "on_device" and "device_copy" calls which only allow + * a \p DLDeviceTypes (as an integer) to be given. + * + * kInvalidDeviceType denotes unconstrained. + */ + int device_type_int; + + DLDeviceType device_type() const { return static_cast(device_type_int); } + + /*! + * \brief The device identifier for the virtual device. This must be resolved to a physical + * device identifier either during compilation or at runtime. + * + * -1 denotes unconstrained. + */ + int virtual_device_id; + + /*! + * \brief The \p Target describing how to compile for the virtual device. + * + * Null denotes unconstrained. Note that if a target later becomes known for this \p SEScope + * then it must be consistent with the \p device_type if already known. This is enforced by the + * Join and Default methods. + */ + Target target; + + /*! + * \brief The scope of memory w.r.t. the virtual device which holds data. + * + * Empty denotes unconstrained. + */ + MemoryScope memory_scope; + + /*! + * \brief Returns true if scope is fully unconstrained, ie no target/device type, device id + * or memory scope is specified. + */ + bool IsFullyUnconstrained() const { + return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 && + memory_scope.empty(); + } + + /*! + * \brief Returns true if scope is fully constrained, ie target, device id and memory scope are + * all specified. + */ + bool IsFullyConstrained() const { + return target.defined() && virtual_device_id != -1 && !memory_scope.empty(); + } + + /*! + * \brief Returns the (virtual) \p Device implied by this \p SEScope. Both the \p device_type and + * \p virtual_device_must be constrained. The returned \p Device may not correspond to any + * physical device available at compile time or even runtime: see "Virtual vs Physical Devices" + * above. + */ + Device ToDevice() const { + ICHECK(device_type() != kInvalidDeviceType); + ICHECK(virtual_device_id != -1); + Device device; + device.device_type = device_type(); + device.device_id = virtual_device_id; + return device; + } + + TVM_DECLARE_ATTRS(SEScopeNode, "SEScope") { + TVM_ATTR_FIELD(device_type_int) + .describe("The type of the virtual device.") + .set_default(kInvalidDeviceType); + TVM_ATTR_FIELD(virtual_device_id) + .describe("The device id of the virtual device.") + .set_default(-1); + TVM_ATTR_FIELD(target) + .describe("The target describing how to compile for the virtual device.") + .set_default(Target()); + TVM_ATTR_FIELD(memory_scope) + .describe("The area of memory w.r.t. the virtual device where data is stored.") + .set_default(""); + } + + friend class SEScope; +}; + +/*! + * \brief Managed reference class to \p SEScopeNode. + * + * \sa SEScopeNode. + */ +class SEScope : public ObjectRef { + public: + /*! + * \brief Construct an SEScope. + * \param device_type The device type for the virtual device, or kInvalidDeviceType if + * unconstrained. If \p target is defined then must match its \p target->kind->device_type. + * \param virtual_device_id The device id for the virtual device, or -1 if unconstrained. + * \param target The target describing how to compile for the virtual device, or null if + * unconstrained. + * \param memory_scope The memory scope w.r.t. the virtual device which holds data, or "" if + * unconstrained. + * \return The SEScope + */ + explicit SEScope(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique fully unconstrained \p SEScope. */ + static SEScope FullyUnconstrained(); + + /*! + * \brief Returns the \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * The target and memory scope will be unconstrained. + */ + static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + ICHECK_GT(device_type, 0); + return SEScope(device_type, virtual_device_id); + } + static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type), virtual_device_id); + } + static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + return ForDeviceType(static_cast(device_type->value), virtual_device_id); + } + + /*! \brief Returns the \p SEScope for \p device. */ + static SEScope ForDevice(const Device& device) { + return ForDeviceType(device.device_type, device.device_id); + } + + /*! \brief Returns the \p SEScope for \p device and \p target. */ + static SEScope ForDeviceAndTarget(const Device& device, Target target) { + return SEScope(device.device_type, device.device_id, std::move(target)); + } + + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ + TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, + MemoryScope memory_scope) { + return SEScope(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); + } + + /*! + * \brief Returns the 'join' of \p lhs and \p rhs. The result will agree pointwise with + * \p lhs and \p rhs on all their constrained fields. Returns the null optional if no such + * join exists, ie there's disagreement on at least one constrained field. + */ + static Optional Join(const SEScope& lhs, const SEScope& rhs); + + /*! + * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any + * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. + */ + static SEScope Default(const SEScope& lhs, const SEScope& rhs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + + friend class SEScopeCache; // Private implementation helper. +}; + +/*! + * \brief A cache of \p SEScopes. This can be used: + * - To avoid ending up with lots of identical instances, since the space of SEScopes for any + * one compilation is very small but the number of points they need to be constructed can + * be very large (eg during device planning). + * - So we can assume \p SEScopes are pointer equal if and only if they are structurally equal. + * This simplifies the unification of 'device domains' which are built on \p SEScopes. + */ +class SEScopeCache { + public: + /*! \brief Returns the unique \p SEScope representing given fields. */ + SEScope Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); + + /*! \brief Returns the unique \p SEScope structurally equal to the given \p se_scope. */ + SEScope Unique(const SEScope& scope); + + private: + /*! \brief Already constructed SEScopes. */ + std::unordered_set cache_; +}; + +} // namespace tvm + +#endif // TVM_TARGET_SE_SCOPE_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 64a1023158e1..21760bdc8dbf 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -66,6 +66,15 @@ class TargetNode : public Object { /*! \return The Optional typed target host of the TargetNode */ TVM_DLL Optional GetHost() const; + /*! + * \brief Returns a human readable representation of \p Target which includes all fields, + * especially the host. Useful for diagnostic messages and debugging. + * + * TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently + * code depends on str() and << being the same. + */ + String ToDebugString() const; + void VisitAttrs(AttrVisitor* v) { v->Visit("kind", &kind); v->Visit("tag", &tag); @@ -110,7 +119,12 @@ class TargetNode : public Object { /*! \brief Get the keys for this target as an unordered_set of string */ TVM_DLL std::unordered_set GetLibs() const; + bool SEqualReduce(const TargetNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; + static constexpr const char* _type_key = "Target"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object); private: @@ -179,6 +193,9 @@ class Target : public ObjectRef { */ TVM_DLL void ExitWithScope(); }; + +using TargetMap = Map; + /*! * \brief Check and update host field of the given legacy target and target host pair. * Note that this function is for legacy target api compatibility issue only, not @@ -187,22 +204,24 @@ class Target : public ObjectRef { * \param host The pointer to a Target typed object for target host to be updated */ void CheckAndUpdateHostConsistency(Target* target, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with values being Target objects + * \param target_map The pointer to a Map objects with values being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(TargetMap* target_map, Target* host); + /*! * \brief Check and update host field of the given legacy heterogeneous targets and * target host.Note that this function is for legacy target api compatibility issue only, * not recommended for other use. - * \param target The pointer to a Map objects with keys being Target objects + * \param ir_modules The pointer to a Map objects with keys being Target objects * \param host The Target typed object for target host to be updated */ -void CheckAndUpdateHostConsistency(Map* target, Target* host); +void CheckAndUpdateHostConsistency(Map* ir_modules, Target* host); } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e4a3d3d1e21b..e482a18c4a5b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -289,6 +289,13 @@ constexpr const char* kIsEntryFunc = "tir.is_entry_func"; */ constexpr const char* kLinkedParams = "tir.linked_params"; +/*! + * \brief Mark the function as the global function called from the host. + * + * Type: Integer + */ +constexpr const char* kIsGlobalFunc = "tir.is_global_func"; + } // namespace attr } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c4aa1c953ab6..ffd860d84cf3 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -194,6 +194,16 @@ class ScheduleNode : public runtime::Object { */ virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) = 0; + /*! + * \brief Sample the factors to perfect tile a specific loop + * \param loop_rv The loop to be tiled + * \param n The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ + virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -210,6 +220,30 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of a specific scope + * \param block_rv The block where the scope is rooted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of under a specific loop + * \param loop_rv The loop under which collecting is conducted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; + /*! + * \brief Get the producer of a specific block + * \param block_rv The block in the query + * \return A list of blocks, the producers of the given block + */ + virtual Array GetProducers(const BlockRV& block_rv) = 0; + /*! + * \brief Get the consumers of a specific block + * \param block_rv The block to be queried + * \return A list of blocks, the consumers of the given block + */ + virtual Array GetConsumers(const BlockRV& block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Fuse a list of consecutive loops into one. It requires: diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..40a0d1ab2f74 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -109,6 +109,12 @@ class Var : public PrimExpr { * \return the new Var copy */ TVM_DLL Var copy_with_suffix(const String& suffix) const; + /*! + * \brief Make a new copy of the variable with specified dtype + * \param dtype The specified dtype + * \return The new variable + */ + TVM_DLL Var copy_with_dtype(DataType dtype) const; /*! * \brief Get pointer to the internal value. diff --git a/licenses/LICENSE.cutlass.txt b/licenses/LICENSE.cutlass.txt new file mode 100644 index 000000000000..64a49d680b1e --- /dev/null +++ b/licenses/LICENSE.cutlass.txt @@ -0,0 +1,23 @@ +Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/setup.py b/python/setup.py index 1b2a9d3ee965..5d21af6b5878 100644 --- a/python/setup.py +++ b/python/setup.py @@ -62,6 +62,13 @@ def get_lib_path(): libs.append(candidate_path) break + # Add microTVM template projects + for name in lib_path: + candidate_path = os.path.join(os.path.dirname(name), "microtvm_template_projects") + if os.path.isdir(candidate_path): + libs.append(candidate_path) + break + else: libs = None diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 57374c54b297..3f2f277d2926 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -23,7 +23,8 @@ # top-level alias # tvm._ffi -from ._ffi.base import TVMError, __version__ +from ._ffi.base import TVMError, __version__, _RUNTIME_ONLY + from ._ffi.runtime_ctypes import DataTypeCode, DataType from ._ffi import register_object, register_func, register_extension, get_global_func @@ -68,7 +69,7 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel -if support.libinfo().get("USE_MICRO", "OFF") == "ON": +if not _RUNTIME_ONLY and support.libinfo().get("USE_MICRO", "OFF") == "ON": from . import micro # NOTE: This file should be python2 compatible so we can diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 0eacd1a1f667..6f35e021daf8 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -58,7 +58,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): opt_level=opt_level, config={ "relay.backend.use_auto_scheduler": True, - "relay.backend.disable_compile_engine_cache": True, }, disabled_pass={"AutoSchedulerLayoutRewrite"}, ): @@ -165,7 +164,8 @@ class TracingMode: """Two modes for tracing""" EXTRACT_TASK = 0 # trace all topi calls to extract tasks - EXTRACT_COMPLEX_TASK_ONLY = 1 # same as EXTRACT_TASK but ignore the task without complex ops + # same as EXTRACT_TASK but ignore the task without complex ops + EXTRACT_COMPLEX_TASK_ONLY = 1 PREPARE_LAYOUT_REWRITE = 2 # trace topi calls to prepare layout rewrite diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 723e7fa77006..7299875bf28d 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -142,7 +142,7 @@ def _traverse_expr(node): params.append(free_var) call = relay.Call(node.op, params, node.attrs) mod = tvm.IRModule.from_expr(relay.Function(params, call)) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() tracing_target = _replace_device_with_tracing(tvm_target) build_thread = threading.Thread( target=relay.build, args=(mod, tracing_target, None, None) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 714dd540d3ab..4716116a1b83 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -127,12 +127,12 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No assert isinstance( mod, tvm.IRModule ), "only support relay Module or Function to be tuned" - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, param)) build_thread.start() build_thread.join() - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() # Clear the warning message cache in FallbackContext if isinstance(DispatchContext.current, FallbackContext): DispatchContext.current.memory = {} diff --git a/python/tvm/contrib/cutlass/__init__.py b/python/tvm/contrib/cutlass/__init__.py new file mode 100644 index 000000000000..69d3e9c4bd7c --- /dev/null +++ b/python/tvm/contrib/cutlass/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""BYOC support for CUTLASS.""" +from .build import tune_cutlass_kernels, build_cutlass_kernels, build_cutlass_kernels_vm diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py new file mode 100644 index 000000000000..615b9003adc9 --- /dev/null +++ b/python/tvm/contrib/cutlass/build.py @@ -0,0 +1,336 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Driver for partitioning and building a Relay module for CUTLASS offload.""" +import logging +import os +import multiprocessing +import tvm +from tvm import runtime, relay +from tvm.contrib.nvcc import find_cuda_path, get_cuda_version +from .gen_gemm import CutlassGemmProfiler + +logger = logging.getLogger("cutlass") + + +def _get_cutlass_path(): + tvm_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../../") + cutlass_path = os.path.join(tvm_root, "3rdparty/cutlass") + assert os.path.exists( + cutlass_path + ), """The CUTLASS root directory not found in {}. + Currently, using CUTLASS requires building TVM from source.""".format( + cutlass_path + ) + return cutlass_path + + +def _get_cutlass_compile_options(sm, threads): + cutlass_root = _get_cutlass_path() + cutlass_include = os.path.join(cutlass_root, "include") + cutlass_util_include = os.path.join(cutlass_root, "tools/util/include") + + kwargs = {} + kwargs["cc"] = "nvcc" + kwargs["options"] = [ + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), + "-Xcompiler=-fPIC", + "-Xcompiler=-Wconversion", + "-Xcompiler=-fno-strict-aliasing", + "-O3", + "-std=c++14", + "-I" + cutlass_include, + "-I" + cutlass_util_include, + ] + cuda_path = find_cuda_path() + cuda_ver = get_cuda_version(cuda_path) + if cuda_ver >= 11.2: + ncpu = multiprocessing.cpu_count() if threads < 0 else threads + kwargs["options"].append("-t %d" % ncpu) + return kwargs + + +class GemmAnnotator(tvm.relay.ExprVisitor): + """Annotates partitioned functions with shape and dtype information.""" + + def __init__(self): + super().__init__() + self.signature = {} + + def visit_call(self, call): + op = call.op + if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + self.signature["op_type"] = op.attrs["Composite"] + for i, arg in enumerate(op.params): + self.signature["arg%d_shape" % i] = arg.checked_type.shape + self.signature["arg%d_dtype" % i] = arg.checked_type.dtype + self.signature["ret_shape"] = op.ret_type.shape + self.signature["ret_dtype"] = op.ret_type.dtype + + +def select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing +): + """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic + workloads.""" + if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): + out = cutlass_profiler.get_default(out_dtype, batched=batched) + logger.info("Picked the default kernel %s", out["name"]) + else: + out = cutlass_profiler.profile( + MM, + NN, + KK, + out_dtype, + batched=batched, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, + ) + if profile_all: + logger.info("The best kernel is %s", out["name"]) + else: + logger.info("Picked the first kernel found %s", out["name"]) + return out + + +def handle_batch_matmul( + cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing +): + """Profile and select a kernel for batch_matmul op workload.""" + MM = arg0_shape[1] + KK = arg0_shape[2] + NN = arg1_shape[1] + + out = select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing + ) + + if op_type == "cutlass.batch_matmul": + cutlass_op_def = out["opdef"] + else: + raise ValueError("%s pattern is not implemented." % op_type) + + return { + "batch": arg0_shape[0], + "batch_stride_A": arg0_shape[1] * arg0_shape[2], + "batch_stride_B": arg1_shape[1] * arg1_shape[2], + "batch_stride_C": arg0_shape[1] * arg1_shape[1], + "cutlass_op_def": cutlass_op_def, + "cutlass_op_name": out["name"], + } + + +def handle_dense( + cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing +): + """Profile and select a kernel for dense op workload.""" + MM = arg0_shape[0] + KK = arg0_shape[1] + NN = arg1_shape[0] + + out = select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing + ) + + if op_type == "cutlass.dense": + cutlass_op_def = out["opdef"] + elif op_type == "cutlass.dense_bias": + cutlass_op_def = out["opdef_bias"] + elif op_type == "cutlass.dense_bias_relu": + cutlass_op_def = out["opdef_bias_relu"] + elif "cutlass.dense_bias_gelu" in op_type: + cutlass_op_def = out["opdef_bias_gelu"] + else: + raise ValueError("%s pattern is not implemented." % op_type) + + return { + "cutlass_op_def": cutlass_op_def, + "cutlass_op_name": out["name"], + } + + +def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): + """Given a module partitioned for CUTLASS offloading, profile each workload to select which + kernels to emit. + + Parameters + ---------- + mod : IRModule + The Relay module with cutlass partitions. + + sm : int + An integer specifying the compute capability. For example, 75 for Turing and + 80 or 86 for Ampere. + + profile_all : bool + Whether or not profile all candidate kernels, or stop profiling after + the first applicable kernel is found. + + use_multiprocessing : bool + Whether or not compile profiler executables for different kernels in parallel. + + tmp_dir : string, optional + A temporary directory where intermediate compiled artifacts will be stored. + + Returns + ------- + mod : IRModule + The updated module annotated with cutlass profiling information. + + num_cutlass_partition : int + The number of partitioned functions created for CUTLASS. + """ + cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + num_cutlass_partition = 0 + for var in mod.get_global_vars(): + fun_name = var.name_hint + func = mod[fun_name] + annotator = GemmAnnotator() + if "cutlass" in fun_name: + num_cutlass_partition += 1 + annotator.visit(func) + out_dtype = annotator.signature["ret_dtype"] + op_type = annotator.signature["op_type"] + + new_attrs = {"op_type": op_type} + new_attrs.update(annotator.signature) + new_attrs.update(func.attrs) + arg0_shape = new_attrs["arg0_shape"] + arg1_shape = new_attrs["arg1_shape"] + + if "batch_matmul" in op_type: + new_attrs.update( + handle_batch_matmul( + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + profile_all, + use_multiprocessing, + ) + ) + elif "dense" in op_type: + new_attrs.update( + handle_dense( + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + profile_all, + use_multiprocessing, + ) + ) + else: + raise ValueError("%s unsupported composite" % op_type) + + if new_attrs["cutlass_op_name"].find("_tn_align") > 0: + new_attrs["lda"] = "K" + new_attrs["ldb"] = "K" + new_attrs["ldc"] = "N" + else: + raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"]) + + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) + new_func = relay.Function( + func.params, + func.body, + ret_type=func.ret_type, + type_params=func.type_params, + attrs=new_attrs, + ) + mod.update_func(var, new_func) + + return mod, num_cutlass_partition + + +def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1): + """Compile CUTLASS kernels in lib and return the runtime module ready to run. + + Parameters + ---------- + lib : GraphExecutorFactoryModule + The output from relay.build containing compiled host code and non-cutlass kernels. + + sm : int + An integer specifying the compute capability. For example, 75 for Turing and + 80 or 86 for Ampere. + + tmp_dir : string, optional + A temporary directory where intermediate compiled artifacts will be stored. + + lib_path : string, optional + The path to a shared library which will be generated as the result of the build process. + + threads : int, optional + The number of threads to use for compiling generated kernels. Only available for + CUDA 11.2 or later. Use all physical cores by default. + + Returns + ------- + updated_lib : runtime.Module + The updated module with compiled cutlass kernels. + """ + kwargs = _get_cutlass_compile_options(sm, threads) + lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) + return runtime.load_module(lib_path) + + +def build_cutlass_kernels_vm( + vm_exec, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro", threads=-1 +): + """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run. + + Parameters + ---------- + vm_exec : vm.Executable + The output from relay.vm.compile containing compiled host code and non-cutlass kernels. + + sm : int + An integer specifying the compute capability. For example, 75 for Turing and + 80 or 86 for Ampere. + + tmp_dir : string, optional + A temporary directory where intermediate compiled artifacts will be stored. + + lib_path : string, optional + The path to a shared library which will be generated as the result of the build process. + + vmcode_path : string, optional + The path where the VM bytecode will be serialized to. + + threads : int, optional + The number of threads to use for compiling generated kernels. Only available for + CUDA 11.2 or later. Use all physical cores by default. + + Returns + ------- + updated_vm_exec: vm.Executable + The updated exectuable with compiled cutlass kernels. + """ + code, lib = vm_exec.save() + kwargs = _get_cutlass_compile_options(sm, threads) + lib_path = os.path.join(tmp_dir, lib_path) + vmcode_path = os.path.join(tmp_dir, vmcode_path) + lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) + with open(vmcode_path, "wb") as fo: + fo.write(code) + lib = tvm.runtime.load_module(lib_path) + code = bytearray(open(vmcode_path, "rb").read()) + return tvm.runtime.vm.Executable.load_exec(code, lib) diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py new file mode 100644 index 000000000000..4673b4bdea65 --- /dev/null +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS GEMM kernels.""" +from .library import * + + +class GemmOperation: + """Describes various attributes for instantiating GEMM kernels.""" + + def __init__( + self, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + ): + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + def accumulator_type(self): + return self.tile_description.math_instruction.element_accumulator + + def short_math_name(self): + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ + inst_shape = "" + intermediate_type = "" + + if ( + self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp + or self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp + ): + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if ( + self.tile_description.math_instruction.element_a != self.A.element + and self.tile_description.math_instruction.element_a + != self.tile_description.math_instruction.element_accumulator + ): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % ( + self.short_math_name(), + inst_shape, + intermediate_type, + "gemm", + ) + + def extended_name(self): + """ Append data types if they differ from compute type. """ + if ( + self.C.element != self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element == self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = substitute_template( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + def layout_name(self): + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, + and layout. + """ + threadblock = self.tile_description.procedural_name() + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + return substitute_template( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment, + }, + ) + + def leading_dim(self): + """ lda, ldb, ldc, according to the leading dimension. """ + if self.A.layout == LayoutType.RowMajor: + lda = "K" + elif self.A.layout == LayoutType.ColumnMajor: + lda = "M" + else: + ValueError("The layout of A is not implemented.") + + if self.B.layout == LayoutType.RowMajor: + ldb = "N" + elif self.B.layout == LayoutType.ColumnMajor: + ldb = "K" + else: + ValueError("The layout of B is not implemented.") + + if self.C.layout == LayoutType.RowMajor: + ldc = "N" + elif self.C.layout == LayoutType.ColumnMajor: + ldc = "M" + else: + ValueError("The layout of B is not implemented.") + + return substitute_template( + "int lda = ${lda_val};\n\tint ldb = ${ldb_val};\n\tint ldc = ${ldc_val};\n", + { + "lda_val": lda, + "ldb_val": ldb, + "ldc_val": ldc, + }, + ) + + +class EmitGemmInstance: + """ Responsible for emitting a CUTLASS template definition.""" + + def __init__(self): + self.epilogue_default = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + self.epilogue_no_beta_scaling = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >""" + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue}, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial} + ${math_operation} + >; +""" + + def emit(self, operation, no_beta_scaling=False, batched=False): + """Instantiate a GEMM kernel from given `operation`.""" + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + epilogue_vector_length = ( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + // DataTypeSize[operation.C.element] + ) + values = { + "operation_name": operation.procedural_name(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + } + + values["kernel_name"] = "GemmBatched" if batched else "Gemm" + values["split_k_serial"] = "" if batched else "false," + + gemm_template = substitute_template( + self.gemm_template, + { + "epilogue": self.epilogue_no_beta_scaling + if no_beta_scaling + else self.epilogue_default + }, + ) + return substitute_template(gemm_template, values) diff --git a/python/tvm/contrib/cutlass/gemm_profiler.py b/python/tvm/contrib/cutlass/gemm_profiler.py new file mode 100644 index 000000000000..13679cd05c42 --- /dev/null +++ b/python/tvm/contrib/cutlass/gemm_profiler.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-outside-toplevel, invalid-name +"""Instantiate a C++ source for profiling CUTLASS kernels.""" + + +class GemmProfilerEmitter(object): + """Emit a C++ source for profiling CUTLASS kernels.""" + + def __init__(self): + from jinja2 import Template + + self.template = Template( + """ +#include +#include +#include +#include + +#include "cuda_runtime.h" +#include "cutlass/gemm/device/gemm.h" + +#define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \\ + << std::endl; \\ + exit(EXIT_FAILURE); \\ + } \\ + } + +#define CUDA_CHECK(status) \\ + { \\ + cudaError_t error = status; \\ + if (error != cudaSuccess) { \\ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \\ + << " at line: " << __LINE__ << std::endl; \\ + exit(EXIT_FAILURE); \\ + } \\ + } + +template +cudaError_t CutlassGemmRCR( + int M, + int N, + int K, + DTypeC alpha, + DTypeA const *A, + int lda, + DTypeB const *B, + int ldb, + DTypeC beta, + DTypeC *C, + int ldc) { + using namespace std::chrono; + {{OperatorDef}} + Operation_{{OperatorName}} gemm_operator; + Operation_{{OperatorName}}::Arguments args({M, N, K}, + {A, lda}, + {B, ldb}, + {C, ldc}, + {C, ldc}, + {alpha, beta}); + cutlass::Status status = gemm_operator(args); + CUTLASS_CHECK(status) + + high_resolution_clock::time_point t1 = high_resolution_clock::now(); + for (int i = 0; i < 100; ++i) { + status = gemm_operator(args); + } + cudaDeviceSynchronize(); + high_resolution_clock::time_point t2 = high_resolution_clock::now(); + duration time_span = duration_cast>(t2 - t1); + std::cout << time_span.count() << std::endl; + return cudaSuccess; +} + + +template +cudaError_t AllocateMatrix(DType **matrix, int ldm, int rows, int columns, int seed = 0) { + cudaError_t result; + + size_t sizeof_matrix = sizeof(DType) * rows * columns; + + // Allocate device memory. + result = cudaMalloc(reinterpret_cast(matrix), sizeof_matrix); + + if (result != cudaSuccess) { + std::cerr << "Failed to allocate matrix: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + // Clear the allocation. + result = cudaMemset(*matrix, 0, sizeof_matrix); + + if (result != cudaSuccess) { + std::cerr << "Failed to clear matrix device memory: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + if (result != cudaSuccess) { + std::cerr << "Failed to initialize matrix: " + << cudaGetErrorString(result) << std::endl; + return result; + } + + return result; +} + +template +cudaError_t TestCutlassGemm(int M, int N, int K, DTypeC alpha, DTypeC beta) { + cudaError_t result; + + {{LeadingDim}} + // size_t sizeof_C = sizeof(DTypeC) * ldc * N; + DTypeA *A; + DTypeB *B; + DTypeC *C_cutlass; + result = AllocateMatrix(&A, lda, M, K, 0); + if (result != cudaSuccess) { + return result; + } + result = AllocateMatrix(&B, ldb, K, N, 17); + if (result != cudaSuccess) { + cudaFree(A); + return result; + } + result = AllocateMatrix(&C_cutlass, ldc, M, N, 101); + if (result != cudaSuccess) { + cudaFree(A); + cudaFree(B); + return result; + } + result = CutlassGemmRCR(M, N, K, alpha, A, lda, B, ldb, + beta, C_cutlass, ldc); + if (result != cudaSuccess) { + std::cerr << "CUTLASS GEMM kernel failed: " + << cudaGetErrorString(result) << std::endl; + cudaFree(C_cutlass); + cudaFree(B); + cudaFree(A); + + return result; + } + cudaFree(C_cutlass); + cudaFree(B); + cudaFree(A); + return cudaSuccess; +} + +int main(int argc, const char *arg[]) { + int problem[3] = { 4096, 4096, 4096 }; + for (int i = 1; i < argc && i < 4; ++i) { + std::stringstream ss(arg[i]); + ss >> problem[i - 1]; + } + float scalars[2] = { 1, 0 }; + cudaError_t result = TestCutlassGemm< {{DTypeA}}, {{DTypeB}}, {{DTypeC}}>( + problem[0], // GEMM M dimension + problem[1], // GEMM N dimension + problem[2], // GEMM K dimension + static_cast<{{DTypeC}}>(scalars[0]), // alpha + static_cast<{{DTypeC}}>(scalars[1]) // beta + ); + return result == cudaSuccess ? 0 : -1; +} +""" + ) + + def emit(self, op_name, op_def, dtype_a, dtype_b, dtype_c, ld): + src = self.template.render( + OperatorName=op_name, + OperatorDef=op_def, + DTypeA=dtype_a, + DTypeB=dtype_b, + DTypeC=dtype_c, + LeadingDim=ld, + ) + return src diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py new file mode 100644 index 000000000000..1ed4bfe5fc4c --- /dev/null +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Kernel generator and profiler for CUTLASS.""" +import logging +import os +import re +import tempfile +import subprocess +import multiprocessing +from .gemm_operation import GemmOperation, EmitGemmInstance +from .gemm_profiler import GemmProfilerEmitter +from .library import ( + EpilogueFunctor, + SwizzlingFunctor, + TensorDescription, + DataTypeTag, + LayoutType, + MathInstruction, + DataType, + OpcodeClass, + MathOperation, + TileDescription, +) + +logger = logging.getLogger("cutlass") + + +def create_gemm_operator( + layouts, + tile_descriptions, + data_type, + alignment_constraints, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + batched=False, +): + """Exhaustively instantiate all kernels from a given configuration.""" + ret = [] + kernel_emitter = EmitGemmInstance() + profiler_emitter = GemmProfilerEmitter() + + element_a, element_b, element_c, element_epilogue = data_type + + if batched: + swizzling_functor = SwizzlingFunctor.Batched + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + op_entry = {} + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + op_bias = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationBias, + swizzling_functor, + ) + op_bias_relu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationRelu, + swizzling_functor, + ) + op_bias_gelu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationGelu, + swizzling_functor, + ) + + kernel_emitter = EmitGemmInstance() + op_entry["op"] = op + op_entry["name"] = op.procedural_name() + op_entry["opdef"] = kernel_emitter.emit(op, batched=batched) + op_entry["opdef_bias"] = kernel_emitter.emit( + op_bias, no_beta_scaling=True, batched=batched + ) + op_entry["opdef_bias_relu"] = kernel_emitter.emit( + op_bias_relu, no_beta_scaling=True, batched=batched + ) + op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched) + op_entry["src"] = profiler_emitter.emit( + op.procedural_name(), + kernel_emitter.emit(op, batched=False), + DataTypeTag[element_a], + DataTypeTag[element_b], + DataTypeTag[element_c], + op.leading_dim(), + ) + op_entry["runtime"] = 9999999 + ret.append(op_entry) + return ret + + +def generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, batched=False +): + """Common kernel generator to be used by archtecture specific generators.""" + ops = [] + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + for math_inst in math_instructions: + tile_descriptions = get_tile_descriptions(math_inst) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + out = create_gemm_operator( + layouts, tile_descriptions, data_type, alignment_constraints, batched=batched + ) + + ops.extend(out) + + return ops + + +def generate_sm75_tensor_op_1688(out_dtype, batched=False): + """Generate GEMM kernels for Turing.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2, 1] + + def get_tile_descriptions(math_inst): + min_cc = 75 + max_cc = 1024 + return [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, batched + ) + + +def generate_sm80_tensor_op_16816(out_dtype, batched=False): + """Generate GEMM kernels for Ampere.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] + + alignment_constraints = [8, 4, 2] + + def get_tile_descriptions(math_inst): + min_cc = 80 + max_cc = 1024 + max_cc_smem_limited = 80 + return [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + return generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, batched + ) + + +GENERATOR_FUNC_TABLE = { + 75: generate_sm75_tensor_op_1688, + 80: generate_sm80_tensor_op_16816, +} + +# TODO(masahi): A sensible way to pick reasonable default kernels +DEFAULT_KERNELS = { + 75: { + "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4", + "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4", + }, + 80: { + "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4", + "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4", + }, +} + + +class ProfilerEngine: + """Compile and run a given profiler executable.""" + + def __init__(self, cuda_arch, cutlass_path, binary_prefix): + self.cuda_arch = cuda_arch + self.binary_prefix = binary_prefix + self.cutlass = cutlass_path + self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format( + cutlass=cutlass_path + ) + self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" + self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format( + arch=cuda_arch + ) + self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing" + self.cmd = "nvcc {cflags} {src} -o {output}" + + def _compile(self, op): + os.makedirs(self.binary_prefix, exist_ok=True) + opath = os.path.join(self.binary_prefix, op["name"]) + if os.path.exists(opath): + return + fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu") + fi.write(op["src"]) + fi.close() + cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath) + os.system(cmd) + os.unlink(fi.name) + + def compile_all(self, ops, use_multiprocessing=False): + """Compile all profiler executables.""" + if use_multiprocessing: + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + pool.map(self._compile, ops) + else: + for op in ops: + self._compile(op) + + def evaluate(self, op, args): + """Run the profiler executable corresponding to op_name with args.""" + op_name = op["name"] + opath = os.path.join(self.binary_prefix, op_name) + if not os.path.exists(opath): + self._compile(op) + cmd = [opath] + if args is not None: + cmd.append(str(args[0])) + cmd.append(str(args[1])) + cmd.append(str(args[2])) + if len(args) > 3: + cmd.append(str(args[3])) + try: + sp = subprocess.run(cmd, capture_output=True, check=True) + rt = float(sp.stdout) + logger.info("%s, %f", op_name, rt) + except subprocess.CalledProcessError: + rt = -1 + return rt + + +class CutlassGemmProfiler(object): + """Profile all candidate kernels and select the best one.""" + + def __init__(self, sm, cutlass_path, binary_path): + assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm + self.engine = ProfilerEngine(sm, cutlass_path, binary_path) + self.sm = sm + self.cache = {} + + def check_align(self, op_name, M): + """Filter out kernels that cannot be supported.""" + aligns = re.findall(r"align[1|2|4|8]", op_name) + assert len(aligns) == 1 + align = int(aligns[0][-1]) + if M % align != 0: + return False + return True + + def get_default(self, out_dtype, batched=False): + """Return the default kernel for the requested architecture. + For now, the default kernel was picked arbitrary. + """ + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) + default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] + filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) + assert len(filtered) == 1 + return filtered[0] + + def profile( + self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False + ): + """Profile and select the best kernel from candidate kernels. + If profile_all is False, return immediately after the first applicable kernel is found. + If use_multiprocessing is True, compile all profiler executables in parallel. + """ + if (M, N, K) in self.cache: + return self.cache[(M, N, K)] + + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) + ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) + + for op in ops: + op["runtime"] = -1 + + if profile_all: + self.engine.compile_all(ops, use_multiprocessing) + + for op in ops: + out = self.engine.evaluate(op, [M, N, K]) + op["runtime"] = out + if out > 0 and profile_all is False: + break + + valid_ops = filter(lambda op: op["runtime"] > 0, ops) + output = sorted(valid_ops, key=lambda i: i["runtime"]) + self.cache[(M, N, K)] = output[0] + return output[0] diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py new file mode 100644 index 000000000000..a3b90ff83d1f --- /dev/null +++ b/python/tvm/contrib/cutlass/library.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Various type definitions to help instantiate CUTLASS kernels.""" +import re +import enum +from enum import auto as enum_auto + + +class GeneratorTarget(enum.Enum): + Library = enum_auto() + + +class DataType(enum.Enum): + f16 = enum_auto() + f32 = enum_auto() + + +ShortDataTypeNames = { + DataType.f16: "h", + DataType.f32: "s", +} + + +DataTypeNames = { + DataType.f16: "f16", + DataType.f32: "f32", +} + +DataTypeTag = { + DataType.f16: "cutlass::half_t", + DataType.f32: "float", +} + +DataTypeSize = { + DataType.f16: 16, + DataType.f32: 32, +} + + +class MathOperation(enum.Enum): + multiply_add = enum_auto() + + +MathOperationTag = { + MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", +} + + +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + + +LayoutTag = { + LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", + LayoutType.RowMajor: "cutlass::layout::RowMajor", +} + + +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, +} + + +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: "n", + LayoutType.RowMajor: "t", +} + + +class OpcodeClass(enum.Enum): + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() + + +OpcodeClassNames = { + OpcodeClass.Simt: "simt", + OpcodeClass.TensorOp: "tensorop", + OpcodeClass.WmmaTensorOp: "wmma_tensorop", +} + +OpcodeClassTag = { + OpcodeClass.Simt: "cutlass::arch::OpClassSimt", + OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", + OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", +} + + +class OperationKind(enum.Enum): + Gemm = enum_auto() + + +OperationKindNames = { + OperationKind.Gemm: "gemm", +} + + +class Target(enum.Enum): + library = enum_auto() + + +def substitute_template(template, values): + """Instantiate a kernel template using `values`.""" + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + + +class GemmKind(enum.Enum): + Gemm = enum_auto() + + +GemmKindNames = { + GemmKind.Gemm: "gemm", +} + + +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationRelu = enum_auto() + LinearCombinationBias = enum_auto() + LinearCombinationGelu = enum_auto() + + +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", + EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu", + EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination", + EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU", +} + + +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + Batched = enum_auto() + + +SwizzlingFunctorTag = { + SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", + SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", + SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", + SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle", +} + + +class MathInstruction: + """Describe characteristics of a math instruction.""" + + def __init__( + self, + instruction_shape, + element_a, + element_b, + element_accumulator, + opcode_class, + math_operation=MathOperation.multiply_add, + ): + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + + +class TileDescription: + """Describe characteristics of a GEMM tile.""" + + def __init__( + self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute + ): + self.threadblock_shape = threadblock_shape + self.stages = stages + self.warp_count = warp_count + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + return "%dx%d_%dx%d" % ( + self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.stages, + ) + + +class TensorDescription: + def __init__(self, element, layout, alignment=1): + self.element = element + self.layout = layout + self.alignment = alignment diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index c26255fc5517..2b142dc75a05 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -70,6 +70,10 @@ def get_onnx_version(): return onnx.__version__ +def get_node_shape(node): + return tuple("Any" if isinstance(i, tvm.tir.Any) else int(i) for i in node.shape) + + def infer_type(node): """A method to infer the type of a relay expression.""" mod = tvm.IRModule.from_expr(node) @@ -521,7 +525,7 @@ def convert(cls, node_entry, model_container, node_dict): input_node = node_dict[node_entry["inputs"][0]] assert len(input_node) == 1, "input node can not be a Tuple" input_node = input_node[0] - shape = input_node["types"][0].concrete_shape + shape = get_node_shape(input_node["types"][0]) indices_or_sect = attrs["indices_or_section"] axis = attrs["axis"] @@ -1019,7 +1023,7 @@ def _add_input(self, node_entry, idx): node_type = node_entry["types"][0] dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] input = onnx.helper.make_tensor_value_info( - node_entry["name"], dtype, shape=node_type.concrete_shape + node_entry["name"], dtype, shape=get_node_shape(node_type) ) self._mc.add_inputs([input]) @@ -1030,7 +1034,7 @@ def _add_output(self, node_entries): for node_type, output_name in zip(node_entry["types"], node_entry["output_names"]): dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] output = onnx.helper.make_tensor_value_info( - output_name, dtype, shape=node_type.concrete_shape + output_name, dtype, shape=get_node_shape(node_type) ) self._mc.add_outputs([output]) diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py new file mode 100644 index 000000000000..720ac29cc6e2 --- /dev/null +++ b/python/tvm/contrib/torch/__init__.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wrong-import-position,redefined-builtin,invalid-name +"""Module container of Pytorch custom class""" +import os +import platform +import torch +from tvm._ffi import libinfo +from tvm.relay.frontend import pytorch + + +def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): + system = platform.system() + if system == "Darwin": + lib_file_name = lib_name + ".dylib" + elif system == "Windows": + lib_file_name = lib_name + ".dll" + else: + lib_file_name = lib_name + ".so" + lib_path = libinfo.find_lib_path()[0] + lib_dir = os.path.dirname(lib_path) + lib_file_path = os.path.join(lib_dir, lib_file_name) + torch.classes.load_library(lib_file_path) + + +_load_platform_specific_library() + +from . import module + +GraphModule = module.GraphModule +VMModule = module.VMModule +TraceTvmModule = module.TraceTvmModule + +from . import pytorch_tvm + +PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule +compile = pytorch_tvm.compile diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py new file mode 100644 index 000000000000..3da9c6f591ce --- /dev/null +++ b/python/tvm/contrib/torch/module.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Module container of PyTorch custom class""" +from typing import List +import torch + + +class GraphModule(torch.nn.Module): + r"""Module container of Pytorch class which wraps exported + TVM op implementation library to be called on Pytorch side""" + + @classmethod + def shape_repr(cls, input_shapes): + return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) + + def __init__(self, num_inputs, num_outputs, device=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.engine = None + + if device is not None: + self.to(device) + self.engine = torch.classes.tvm_dsoop.TvmGraphModule(num_inputs, num_outputs, self.device) + + def init(self, input_shapes, lib_path, graph_path, params_path): + r"""Load tvm module""" + self.engine.load_tvm_module(input_shapes, lib_path, graph_path, params_path) + + def forward(self, inputs: List[torch.Tensor]): + r"""Call tvm module to forward""" + return self.engine.forward(inputs) + + @property + def device(self): + r"""Get the device string""" + return str(self.dummy_param.device) + + def _apply(self, fn): + r"""Override to device function, manually move tvm module to desired device""" + super()._apply(fn) + if self.engine is not None: + self.engine.to(self.device) + return self + + +class VMModule(torch.nn.Module): + r"""Module container of Pytorch class which wraps exported + TVM op implementation library to be called on Pytorch side""" + + @classmethod + def shape_repr(cls, input_shapes): + return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) + + def __init__(self, num_inputs, num_outputs, device=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.empty(0)) + self.engine = None + + if device is not None: + self.to(device) + self.engine = torch.classes.tvm_dsoop.TvmVMModule(num_inputs, num_outputs, self.device) + + def init(self, input_shapes, lib_path, code_path): + r"""Load tvm module""" + self.engine.load_tvm_module(input_shapes, lib_path, code_path) + + def forward(self, inputs: List[torch.Tensor]): + r"""Call tvm module to forward""" + return self.engine.forward(inputs) + + @property + def device(self): + r"""Get the device string""" + return str(self.dummy_param.device) + + def _apply(self, fn): + r"""Override to device function, manually move tvm module to desired device""" + super()._apply(fn) + if self.engine is not None: + self.engine.to(self.device) + return self + + +class TraceTvmModule(torch.nn.Module): + r"""Wrapper for trace GraphModule + + GraphModule and VMModule only supports List[Tensor] inputs and cannot be traced. + This is a wrapper class for trace GraphModule or VMModule in order to support + arbitrary number of inputs + + Example: + import tvm.contrib.torch + tvm_module = tvm.contrib.torch.GraphModule(1, 1, 'cuda:0') + tvm_module.init(input_shapes, lib_path, graph_path, params_path) + + trace_wrapper = tvm.contrib.torch.TraceGraphModule(torch.jit.script(tvm_module)) + traced = torch.jit.trace(trace_wrapper, example_inputs) + """ + + def __init__(self, tvm_module): + super().__init__() + self.tvm_module = tvm_module + + def forward(self, *inputs): + outputs = self.tvm_module(inputs) + return outputs[0] if len(outputs) == 1 else tuple(outputs) diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py new file mode 100644 index 000000000000..1e50c98ab883 --- /dev/null +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""`compile` api that convert torch module to torch tvm module""" +import os +import tvm +import tvm.testing +from tvm import relay, autotvm +from tvm.runtime import load_module +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_executor +from tvm.contrib.debugger import debug_executor +from . import GraphModule + + +def tune_tasks( + tasks, + measure_option, + tuner="xgb", + n_trial=1000, + early_stopping=None, + log_filename="tuning.log", + use_transfer_learning=True, +): + """Tune tasks and generate tuning log to file""" + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = f"[Task {i + 1:2d}/{len(tasks):2d}] " + + # create tuner + if tuner in ("xgb", "sgb-rank"): + tuner_obj = XGBTuner(tsk, loss_type="rank") + elif tuner == "ga": + tuner_obj = GATuner(tsk, pop_size=100) + elif tuner == "random": + tuner_obj = RandomTuner(tsk) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + # pick best records to a cache file + if not os.path.exists(log_filename): + with open(log_filename, "w", encoding="utf-8"): + pass + if os.path.exists(tmp_log_file): + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + + +def get_tuning_opt(log_file="tuning.log", n_trial=200): + """Returns tuning options""" + tuning_opt = { + "log_filename": log_file, + "tuner": "random", + "n_trial": n_trial, + "early_stopping": 60, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(timeout=10), + runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150), + ), + } + return tuning_opt + + +TVM_ASSETS = ["mod.so", "graph.json", "params"] + + +class PyTorchTVMModule: + """Helper class for compiling pytorch module to tvm module""" + + def __init__(self, target="cuda", device=tvm.cuda(0)) -> None: + self.script_module = None + self.input_infos = None + self.default_dtype = "float32" + self.mod = None + self.params = None + self.tasks = None + self.target = target + self.dev = device + self.log_file = None + self.tvm_module = None + self.tvm_graph = None + self.tvm_lib = None + self.tvm_params = None + + def from_pytorch(self, script_module, input_infos, default_dtype="float32"): + self.script_module = script_module + self.input_infos = input_infos + self.default_dtype = default_dtype + self.mod, self.params = relay.frontend.from_pytorch( + script_module, input_infos, default_dtype=default_dtype + ) + + def tune_tvm(self, log_file="tuning.log", n_trial=200): + self.tasks = autotvm.task.extract_from_program( + self.mod["main"], + target=self.target, + params=self.params, + ) + self.log_file = log_file + tuning_opt = get_tuning_opt(log_file, n_trial) + tune_tasks(self.tasks, **tuning_opt) + + def build_tvm(self, export_dir, debug_runtime=False): + tvm_mod = self._build_tvm(debug_runtime) + self._export_tvm(export_dir) + return tvm_mod + + def _build_tvm(self, debug_runtime=False): + # compile kernels with history best records + with autotvm.apply_history_best(self.log_file): + with tvm.transform.PassContext(opt_level=3): + self.tvm_graph, self.tvm_lib, self.tvm_params = relay.build( + self.mod, target=self.target, params=self.params + ) + + if not debug_runtime: + self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) + else: + self.tvm_module = debug_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) + self.tvm_module.set_input(**self.tvm_params) + return self.tvm_module + + def _export_tvm(self, export_dir): + if not os.path.isdir(export_dir): + os.makedirs(export_dir) + self.export_dir = export_dir + self.tvm_lib.export_library(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "w", encoding="utf8") as fout: + fout.write(self.tvm_graph) + with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout: + fout.write(relay.save_param_dict(self.tvm_params)) + + def load_tvm(self, export_dir): + """Load tvm module from export directory""" + self.export_dir = export_dir + self.tvm_lib = load_module(os.path.join(export_dir, TVM_ASSETS[0])) + with open(os.path.join(export_dir, TVM_ASSETS[1]), "r", encoding="utf8") as f: + self.tvm_graph = f.read() + with open(os.path.join(export_dir, TVM_ASSETS[2]), "rb") as f: + self.tvm_params = relay.load_param_dict(f.read()) + + self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, device=self.dev) + self.tvm_module.set_input(**self.tvm_params) + return self.tvm_module + + def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): + """Build pytorch module containing TVM Graph Module""" + assert self.export_dir, "you must build_tvm or load_tvm before" + input_infos = input_infos or self.input_infos + assert input_infos + assert len(input_infos) == num_inputs + assets = [os.path.join(self.export_dir, i) for i in TVM_ASSETS] + input_shapes = [i[1] for i in input_infos] + + def _tvm_dev_to_pt_dev(device): + """convert tvm device to pytorch device string""" + if tvm.runtime.Device.MASK2STR[device.device_type] == "cpu": + return "cpu" + if tvm.runtime.Device.MASK2STR[device.device_type] == "cuda": + return f"cuda:{device.device_id}" + raise ValueError(f"unsupported device for pt graph module: {device}") + + mod = GraphModule(num_inputs=num_inputs, num_outputs=num_outputs).to( + _tvm_dev_to_pt_dev(self.dev) + ) + mod.init(input_shapes, *assets) + return mod + + +def compile(script_module, option): + """ + example: + option = { + "input_infos": [ + ("x", (1, 3, 244, 244)), + ], + "default_dtype": "float16", + "export_dir": "pytorch_compiled", + "num_outputs": 1, + "tuning_n_trials": 20, # set zero to skip tuning + "tuning_log_file": "tuning.log", + "target": "llvm", + "device": tvm.cpu(), + } + script_module = torch.jit.script(model) + pytorch_tvm_module = compile(script_module, option) + pytorch_tvm_module("model_tvm.pt") + """ + input_infos = option["input_infos"] + default_dtype = option.get("default_dtype", "float32") + export_dir = option.get("export_dir", "pytorch_compiled") + tuning_log_file = option.get("tuning_log_file", "tuning.log") + tuning_n_trials = option.get("tuning_n_trials", 20) + num_outputs = option.get("num_outputs", 1) + target = option.get("target", "cuda") + device = option.get("device", tvm.cuda(0)) + + mod = PyTorchTVMModule(target=target, device=device) + print("Converting...") + + mod.log_file = tuning_log_file + mod.from_pytorch(script_module, input_infos, default_dtype) + + if tuning_n_trials > 0: + print("Tuning...") + mod.tune_tvm(log_file=tuning_log_file, n_trial=tuning_n_trials) + + print("Building...") + mod.build_tvm(export_dir) + pytorch_mod = mod.build_pytorch_module(num_inputs=len(input_infos), num_outputs=num_outputs) + return pytorch_mod diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 1ee24cf69d44..65b0c3dbc0aa 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -27,6 +27,7 @@ import tvm +from tvm.driver import tvmc from tvm import relay from tvm import transform from tvm._ffi import registry @@ -206,6 +207,7 @@ def parse_target(target): a key-value for all options passed via CLI; 'raw', containing the plain string for this codegen """ + codegen_names = tvmc.composite_target.get_codegen_names() codegens = [] tvm_target_kinds = tvm.target.Target.list_kinds() @@ -232,7 +234,7 @@ def parse_target(target): for codegen_def in split_codegens: # the first is expected to be the name name = codegen_def[0] - is_tvm_target = name in tvm_target_kinds + is_tvm_target = name in tvm_target_kinds and name not in codegen_names raw_target = " ".join(codegen_def) all_opts = codegen_def[1:] if len(codegen_def) > 1 else [] opts = {} diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index ba7862378557..0c04d2b7248f 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -23,7 +23,8 @@ import tvm.contrib.target.vitis_ai # pylint: disable=unused-import from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib -from tvm.relay.op.contrib.ethosn import partition_for_ethosn +from tvm.relay.op.contrib.ethosn import partition_for_ethosn77 +from tvm.relay.op.contrib.ethosn import partition_for_ethosn78 from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.op.contrib.bnns import partition_for_bnns @@ -57,7 +58,11 @@ }, "ethos-n77": { "config_key": "relay.ext.ethos-n.options", - "pass_pipeline": partition_for_ethosn, + "pass_pipeline": partition_for_ethosn77, + }, + "ethos-n78": { + "config_key": "relay.ext.ethos-n.options", + "pass_pipeline": partition_for_ethosn78, }, "ethos-u": { "config_key": "relay.ext.ethosu.options", diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 21d3d59fb013..13ab3dd170c3 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -262,7 +262,9 @@ def load(self, path, shape_dict=None, **kwargs): input_shapes = list(shape_dict.items()) logger.debug("parse Torch model and convert into Relay computation graph") - return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs) + return relay.frontend.from_pytorch( + traced_model, input_shapes, keep_quantized_weight=True, **kwargs + ) class PaddleFrontend(Frontend): diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index 7a078b8be087..a551293b26a5 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -18,6 +18,7 @@ This file contains functions for processing target inputs for the TVMC CLI """ +from tvm.driver import tvmc from tvm.target import Target # We can't tell the type inside an Array but all current options are strings so @@ -27,6 +28,11 @@ INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} +def _valid_target_kinds(): + codegen_names = tvmc.composite_target.get_codegen_names() + return filter(lambda target: target not in codegen_names, Target.list_kinds()) + + def _generate_target_kind_args(parser, kind): target_group = parser.add_argument_group(f"target {kind.name}") for target_option, target_type in kind.options.items(): @@ -45,8 +51,7 @@ def generate_target_args(parser): help="compilation target as plain string, inline JSON or path to a JSON file", required=True, ) - target_kinds = Target.list_kinds() - for target_kind in target_kinds: + for target_kind in _valid_target_kinds(): target = Target(target_kind) _generate_target_kind_args(parser, target.kind) @@ -55,7 +60,7 @@ def _reconstruct_target_kind_args(args, kind): kind_options = {} for target_option, target_type in kind.options.items(): if target_type in INTERNAL_TO_NATIVE_TYPE: - var_name = f"target_{kind.name}_{target_option.replace('-', '_')}" + var_name = f"target_{kind.name.replace('-', '_')}_{target_option.replace('-', '_')}" option_value = getattr(args, var_name) if option_value is not None: kind_options[target_option] = getattr(args, var_name) @@ -64,9 +69,8 @@ def _reconstruct_target_kind_args(args, kind): def reconstruct_target_args(args): """Reconstructs the target options from the arguments""" - target_kinds = Target.list_kinds() reconstructed = {} - for target_kind in target_kinds: + for target_kind in _valid_target_kinds(): target = Target(target_kind) kind_options = _reconstruct_target_kind_args(args, target.kind) if kind_options: diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 1948a6787eac..70c482e06125 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -29,52 +29,140 @@ class PassInstrument(tvm.runtime.Object): """A pass instrument implementation. - Users don't need to interact with this class directly. - Instead, a `PassInstrument` instance should be created through - :py:func:`pass_instrument` + To use, a user class can either subclass from PassInstrument + directly, or can apply the :py:func:`pass_instrument` wrapper. In + either case, the `enter_pass_ctx`, `exit_pass_ctx`, `should_run`, + `run_before_pass`, and `run_after_pass` methods can be defined to + adjust the instrument's behavior. See the no-op implementations + in this class definition for more information on each. + """ + def __init__(self): + # initialize handle in case pi_cls creation failed. + self.handle = None + cls = type(self) + + # If the child class declared the method, then use it. + # Otherwise, pass None to avoid a C++ -> Python round trip for + # a no-op. + def get_child_method(name): + if getattr(cls, name) is getattr(PassInstrument, name): + return None + + return getattr(self, name) + + # Create runtime pass instrument object. + # reister instance's enter_pass_ctx,exit_pass_ctx, should_run, run_before_pass and + # run_after_pass methods to it if present. + self.__init_handle_by_constructor__( + _ffi_instrument_api.PassInstrument, + cls.__name__, + get_child_method("enter_pass_ctx"), + get_child_method("exit_pass_ctx"), + get_child_method("should_run"), + get_child_method("run_before_pass"), + get_child_method("run_after_pass"), + ) + + def enter_pass_ctx(self): + """Called when entering the instrumented context. + + Returns + ------- + None + """ + + def exit_pass_ctx(self): + """Called when exiting the instrumented context. + + Returns + ------- + None + """ + + def should_run(self, mod, info): + """Determine whether to run the pass or not. + + Called once for each pass that is run while the instrumented + context is active. + + Parameters + ---------- + mod : tvm.ir.module.IRModule + + The module on which an optimization pass is being run. + + info : tvm.transform.PassInfo + + The pass information. + + Returns + ------- + should_run : bool + + True to run the pass, or False to skip the pass. + """ + + def run_before_pass(self, mod, info): + """Instrument before the pass runs. + + Called once for each pass that is run while the instrumented + context is active. + + Parameters + ---------- + mod : tvm.ir.module.IRModule + + The module on which an optimization pass is being run. + + info : tvm.transform.PassInfo + + The pass information. + + Returns + ------- + None + """ + + def run_after_pass(self, mod, info): + """Instrument after the pass runs. + + Called once for each pass that is run while the instrumented + context is active. + + Parameters + ---------- + mod : tvm.ir.module.IRModule + + The module on which an optimization pass is being run. + + info : tvm.transform.PassInfo + + The pass information. + + Returns + ------- + None + """ + def _wrap_class_pass_instrument(pi_cls): """Wrap a python class as pass instrument""" - class PyPassInstrument(PassInstrument): + # No additional wrapping needed if the user class already + # inherits. + if issubclass(pi_cls, PassInstrument): + return pi_cls + + class PyPassInstrument(pi_cls, PassInstrument): """Internal wrapper class to create a class instance.""" def __init__(self, *args, **kwargs): # initialize handle in case pi_cls creation failed. self.handle = None - inst = pi_cls(*args, **kwargs) - - # check method declartion within class, if found, wrap it. - def create_method(method): - if hasattr(inst, method) and inspect.ismethod(getattr(inst, method)): - - def func(*args): - return getattr(inst, method)(*args) - - func.__name__ = "_" + method - return func - return None - - # create runtime pass instrument object - # reister instance's enter_pass_ctx,exit_pass_ctx, should_run, run_before_pass and - # run_after_pass methods to it if present. - self.__init_handle_by_constructor__( - _ffi_instrument_api.PassInstrument, - pi_cls.__name__, - create_method("enter_pass_ctx"), - create_method("exit_pass_ctx"), - create_method("should_run"), - create_method("run_before_pass"), - create_method("run_after_pass"), - ) - - self._inst = inst - - def __getattr__(self, name): - # fall back to instance attribute if there is not any - return self._inst.__getattribute__(name) + pi_cls.__init__(self, *args, **kwargs) + PassInstrument.__init__(self) functools.update_wrapper(PyPassInstrument.__init__, pi_cls.__init__) PyPassInstrument.__name__ = pi_cls.__name__ diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 11ef823ab481..1a705b999b74 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -256,7 +256,7 @@ def __str__(self): def __repr__(self): return self.astext() - def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str: + def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript Parameters diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 2e280ef20ac3..47b3dda5a36e 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -21,4 +21,5 @@ from . import runner from . import space_generator from . import search_strategy +from . import integration from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index ed81f4c0d3f9..381051e85f55 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -23,6 +23,7 @@ from tvm.target import Target from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.BuilderInput") @@ -119,6 +120,7 @@ class PyBuilder(Builder): def __init__(self): """Constructor.""" + @check_override(self.__class__, Builder) def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: return self.build(build_inputs) @@ -126,6 +128,3 @@ def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]: _ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member f_build, ) - - def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 3d05441fe22b..fd746e640c76 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -25,7 +25,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo -from ..utils import _json_de_tvm +from ..utils import _json_de_tvm, check_override @register_object("meta_schedule.Workload") @@ -207,15 +207,19 @@ class PyDatabase(Database): def __init__(self): """Constructor.""" + @check_override(self.__class__, Database) def f_commit_workload(mod: IRModule) -> Workload: return self.commit_workload(mod) + @check_override(self.__class__, Database) def f_commit_tuning_record(record: TuningRecord) -> None: self.commit_tuning_record(record) + @check_override(self.__class__, Database) def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]: return self.get_top_k(workload, top_k) + @check_override(self.__class__, Database, func_name="__len__") def f_size() -> int: return len(self) @@ -226,15 +230,3 @@ def f_size() -> int: f_get_top_k, f_size, ) - - def commit_workload(self, mod: IRModule) -> Workload: - raise NotImplementedError - - def commit_tuning_record(self, record: TuningRecord) -> None: - raise NotImplementedError - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - raise NotImplementedError - - def __len__(self) -> int: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py new file mode 100644 index 000000000000..47003c6faa25 --- /dev/null +++ b/python/tvm/meta_schedule/integration.py @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta schedule integration with high-level IR""" +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional, Union + +from tvm._ffi import register_object +from tvm.ir import IRModule, transform +from tvm.relay import Any, Function as RelayFunc, vm +from tvm.runtime import NDArray, Object +from tvm.target import Target +from tvm.tir import PrimFunc + +from . import _ffi_api + + +@register_object("meta_schedule.ExtractedTask") +class ExtractedTask(Object): + """A tuning task extracted from the high-level IR + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : List[IRModule] + A list of low-level IRs that the high-level IR could potentially dispatch to + """ + + task_name: str + mod: IRModule + dispatched: List[IRModule] + + def __init__( + self, + task_name: str, + mod: IRModule, + dispatched: List[IRModule], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ExtractedTask, # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + +@register_object("meta_schedule.MetaScheduleContext") +class MetaScheduleContext(Object): + """A context manager interface for the integration""" + + def query( + self, + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration + + Parameters + ---------- + task_name : str + The name of the task extracted + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.MetaScheduleContextQuery( # type: ignore # pylint: disable=no-member + self, + task_name, + mod, + dispatched, + ) + + @staticmethod + def current() -> Optional["MetaScheduleContext"]: + """The context manager in the current scope + + Returns + ------- + ctx : Optional[MetaScheduleContext] + The MetaScheduleContext in the current scope. + NullOpt if it's currently not under any MetaScheduleContext. + """ + return _ffi_api.MetaScheduleContextCurrent() # type: ignore # pylint: disable=no-member + + @staticmethod + def query_inside_with_scope( + task_name: str, + mod: IRModule, + dispatched: Optional[List[IRModule]], + ) -> Union[IRModule, RelayFunc, PrimFunc, None]: + """The entry point of the integration workflow. The compilation process of the high-level + IR should call this method for task extraction and for feedback hints + + Basically, this method is equivalent to: + + .. code-block:: python + + def query_inside_with_scope(task_name, mod, dispatched): + ctx = MetaScheduleContext.current() + assert ctx is not None + ctx.query(task_name, mod, dispatched) + + Parameters + ---------- + task_name : str + The name of the task + mod : IRModule + The high-level IR + dispatched : Optional[List[IRModule]] + A list of low-level IRs that the high-level IR could potentially dispatch to + + Returns + ------- + result : Union[IRModule, RelayFunc, PrimFunc, None] + There are different types of the output: + 1) NullOpt if there is no feedback hint; + 2) tir::PrimFunc if `mod` should be lowered to a PrimFunc; + 3) relay::Function if `mod` should be dispatched to BYOC workflow; + 4) IRModule for unified dispatch + """ + return _ffi_api.MetaScheduleContextQueryInsideWithScope( # type: ignore # pylint: disable=no-member + task_name, + mod, + dispatched, + ) + + def __enter__(self) -> "MetaScheduleContext": + """Entering the scope of the context manager""" + _ffi_api.MetaScheduleContextEnterScope(self) # type: ignore # pylint: disable=no-member + return self + + def __exit__(self, ptype, value, trace) -> None: + """Exiting the scope of the context manager""" + _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TaskExtraction") +class TaskExtraction(MetaScheduleContext): + """An integration context for task extraction""" + + tasks: List[ExtractedTask] + """The extracted tasks""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.ApplyHistoryBest") +class ApplyHistoryBest(MetaScheduleContext): + pass + + +def extract_task( + mod: Union[IRModule, RelayFunc], + target: Target, + params: Optional[Dict[str, NDArray]] = None, + *, + opt_level: int = 3, + pass_config: Dict[str, Any] = { + "relay.backend.use_meta_schedule": True, + }, + disabled_pass: List[str] = [], +) -> List[ExtractedTask]: + """Extract tuning tasks from a relay program. + + Parameters + ---------- + mod : Union[tvm.IRModule, tvm.relay.Function] + The module or function to tune + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + opt_level : int + The optimization level of the compiler + pass_config : Dict[str, Any] + The pass config of the compiler + disabled_pass : List[str] + The list of disabled passes of the compiler + + Returns + ------- + tasks: List[ExtractedTask] + The tasks extracted from this network + """ + + @contextmanager + def _autotvm_silencer(): + from tvm import autotvm # pylint: disable=import-outside-toplevel + + silent = autotvm.GLOBAL_SCOPE.silent + autotvm.GLOBAL_SCOPE.silent = True + try: + yield + finally: + autotvm.GLOBAL_SCOPE.silent = silent + + def _thread_run(func: Callable[[], None]) -> None: + import threading # pylint: disable=import-outside-toplevel + + thread = threading.Thread(target=func) + thread.start() + thread.join() + + env = TaskExtraction() + if isinstance(mod, RelayFunc): + mod = IRModule.from_expr(mod) + if not isinstance(target, Target): + target = Target(target) + + def _func(): + with env, _autotvm_silencer(), transform.PassContext( + config=pass_config, + disabled_pass=disabled_pass, + opt_level=opt_level, + ): + compiler = vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target) + + _thread_run(_func) + return env.tasks diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 9f7be8ea4af4..71a557dca3a3 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -22,6 +22,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo +from ..utils import check_override @register_object("meta_schedule.RunnerInput") @@ -158,6 +159,7 @@ class PyRunner(Runner): def __init__(self) -> None: """Constructor""" + @check_override(self.__class__, Runner) def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: return self.run(runner_inputs) @@ -165,6 +167,3 @@ def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member f_run, ) - - def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index d270ea61f6dc..6cee09edd4fc 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Search Strategy""" - +""" +Meta Schedule search strategy that generates the measure +candidates for measurement. +""" from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object @@ -25,6 +27,7 @@ from .. import _ffi_api from ..arg_info import ArgInfo from ..runner import RunnerResult +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -126,18 +129,23 @@ class PySearchStrategy(SearchStrategy): def __init__(self): """Constructor.""" + @check_override(self.__class__, SearchStrategy) def f_initialize_with_tune_context(context: "TuneContext") -> None: self.initialize_with_tune_context(context) + @check_override(self.__class__, SearchStrategy) def f_pre_tuning(design_spaces: List[Schedule]) -> None: self.pre_tuning(design_spaces) + @check_override(self.__class__, SearchStrategy) def f_post_tuning() -> None: self.post_tuning() + @check_override(self.__class__, SearchStrategy) def f_generate_measure_candidates() -> List[MeasureCandidate]: return self.generate_measure_candidates() + @check_override(self.__class__, SearchStrategy) def f_notify_runner_results(results: List["RunnerResult"]) -> None: self.notify_runner_results(results) @@ -149,18 +157,3 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None: f_generate_measure_candidates, f_notify_runner_results, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def pre_tuning(self, design_spaces: List[Schedule]) -> None: - raise NotImplementedError - - def post_tuning(self) -> None: - raise NotImplementedError - - def generate_measure_candidates(self) -> List[MeasureCandidate]: - raise NotImplementedError - - def notify_runner_results(self, results: List["RunnerResult"]) -> None: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 798753d91345..e37fd14ba440 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,7 +18,6 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ - from typing import TYPE_CHECKING, List from tvm._ffi import register_object @@ -27,6 +26,7 @@ from tvm.tir.schedule import Schedule from .. import _ffi_api +from ..utils import check_override if TYPE_CHECKING: from ..tune_context import TuneContext @@ -74,9 +74,11 @@ class PySpaceGenerator(SpaceGenerator): def __init__(self): """Constructor.""" + @check_override(self.__class__, SpaceGenerator) def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: self.initialize_with_tune_context(tune_context) + @check_override(self.__class__, SpaceGenerator) def f_generate_design_space(mod: IRModule) -> List[Schedule]: return self.generate_design_space(mod) @@ -85,9 +87,3 @@ def f_generate_design_space(mod: IRModule) -> List[Schedule]: f_initialize_with_tune_context, f_generate_design_space, ) - - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: - raise NotImplementedError - - def generate_design_space(self, mod: IRModule) -> List[Schedule]: - raise NotImplementedError diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index f1e21ad3ddfe..aeea154cfe02 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,20 +15,65 @@ # specific language governing permissions and limitations # under the License. """Auto-tuning Task Scheduler""" + +from typing import List + from tvm._ffi import register_object from tvm.runtime import Object +from ..runner import Runner +from ..builder import Builder +from ..database import Database +from ..tune_context import TuneContext from .. import _ffi_api +from ..utils import check_override @register_object("meta_schedule.TaskScheduler") class TaskScheduler(Object): - """The abstract task scheduler interface.""" + """The abstract task scheduler interface. + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + """ + + tasks: List[TuneContext] + builder: Builder + runner: Runner + database: Database def tune(self) -> None: """Auto-tuning.""" _ffi_api.TaskSchedulerTune(self) # type: ignore # pylint: disable=no-member + def next_task_id(self) -> int: + """Fetch the next task id. + + Returns + ------- + int + The next task id. + """ + return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member + + def _initialize_task(self, task_id: int) -> None: + """Initialize modules of the given task. + + Parameters + ---------- + task_id : int + The task id to be initialized. + """ + _ffi_api.TaskSchedulerInitializeTask(self, task_id) # type: ignore # pylint: disable=no-member + def _set_task_stopped(self, task_id: int) -> None: """Set specific task to be stopped. @@ -64,59 +109,74 @@ def _join_running_task(self, task_id: int) -> None: """ _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member - def _next_task_id(self) -> int: - """Fetch the next task id. - - Returns - ------- - int - The next task id. - """ - return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member - @register_object("meta_schedule.PyTaskScheduler") class PyTaskScheduler(TaskScheduler): """An abstract task scheduler with customized methods on the python-side.""" - def __init__(self): - """Constructor.""" + def __init__( + self, + tasks: List[TuneContext], + builder: Builder, + runner: Runner, + database: Database, + ): + """Constructor. + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + """ + + @check_override(self.__class__, TaskScheduler, required=False) def f_tune() -> None: self.tune() + @check_override(self.__class__, TaskScheduler) + def f_next_task_id() -> int: + return self.next_task_id() + + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_initialize_task" + ) + def f_initialize_task(task_id: int) -> None: + self._initialize_task(task_id) + + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_set_task_stopped" + ) def f_set_task_stopped(task_id: int) -> None: self._set_task_stopped(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_is_task_running" + ) def f_is_task_running(task_id: int) -> bool: return self._is_task_running(task_id) + @check_override( + PyTaskScheduler, TaskScheduler, required=False, func_name="_join_running_task" + ) def f_join_running_task(task_id: int) -> None: self._join_running_task(task_id) - def f_next_task_id() -> int: - return self._next_task_id() - self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerPyTaskScheduler, # type: ignore # pylint: disable=no-member + tasks, + builder, + runner, + database, f_tune, + f_initialize_task, f_set_task_stopped, f_is_task_running, f_join_running_task, f_next_task_id, ) - - def tune(self) -> None: - raise NotImplementedError() - - def _set_task_stopped(self, task_id: int) -> None: - _ffi_api.TaskSchedulerSetTaskStopped(self, task_id) # type: ignore # pylint: disable=no-member - - def _is_task_running(self, task_id: int) -> bool: - return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member - - def _join_running_task(self, task_id: int) -> None: - _ffi_api.TaskSchedulerJoinRunningTask(self, task_id) # type: ignore # pylint: disable=no-member - - def _next_task_id(self) -> int: - return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py new file mode 100644 index 000000000000..7e516a510f66 --- /dev/null +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Testing utilities in meta schedule""" +from .local_rpc import LocalRPC +from .relay_workload import get_network diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing/local_rpc.py similarity index 97% rename from python/tvm/meta_schedule/testing.py rename to python/tvm/meta_schedule/testing/local_rpc.py index b286e3b18a93..cd1221124cc9 100644 --- a/python/tvm/meta_schedule/testing.py +++ b/python/tvm/meta_schedule/testing/local_rpc.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Testing utilities in meta schedule""" +"""RPC tracker and server running locally""" from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py new file mode 100644 index 000000000000..1eb9950f7fc7 --- /dev/null +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in Relay IR""" +from typing import Dict, Tuple + +import tvm.relay.testing # pylint: disable=unused-import +from tvm import relay +from tvm.ir import IRModule +from tvm.runtime import NDArray + + +def get_network( + name: str, + batch_size: int, + layout: str = "NHWC", + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: + """Get the symbol definition and random weight of a network""" + # meta-schedule prefers NHWC layout + if layout == "NHWC": + image_shape = (224, 224, 3) + elif layout == "NCHW": + image_shape = (3, 224, 224) + else: + raise ValueError("Invalid layout: " + layout) + + input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape + output_shape: Tuple[int, int] = (batch_size, 1000) + + if name.startswith("resnet-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name.startswith("resnet3d-"): + n_layer = int(name.split("-")[1]) + mod, params = relay.testing.resnet.get_workload( + num_layers=n_layer, + batch_size=batch_size, + layout=layout, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "mobilenet": + mod, params = relay.testing.mobilenet.get_workload( + batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape + ) + elif name == "squeezenet_v1.1": + assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout" + mod, params = relay.testing.squeezenet.get_workload( + version="1.1", + batch_size=batch_size, + dtype=dtype, + image_shape=image_shape, + ) + elif name == "inception_v3": + input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == "mxnet": + from mxnet.gluon.model_zoo.vision import get_model # type: ignore # pylint: disable=import-outside-toplevel + + assert layout == "NCHW" + block = get_model("resnet50_v1", pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function( + net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs + ) + mod = IRModule.from_expr(net) + return mod, params, input_shape, output_shape diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index c79137d55dda..a9ef514543f8 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -205,3 +205,43 @@ def structural_hash(mod: IRModule) -> str: # but ffi can't handle unsigned integers properly so it's parsed into a negative number shash += 1 << 64 return str(shash) + + +def check_override( + derived_class: Any, base_class: Any, required: bool = True, func_name: str = None +) -> Callable: + """Check if the derived class has overridden the base class's method. + + Parameters + ---------- + derived_class : Any + The derived class. + base_class : Any + The base class of derived class. + required : bool + If the method override is required. + func_name : str + Name of the method. Default value None, which would be set to substring of the given + function, e.g. `f_generate`->`generate`. + + Returns + ------- + func : Callable + Raise NotImplementedError if the function is required and not overridden. If the + function is not overridden return None, other return the overridden function. + """ + + def inner(func: Callable): + + if func_name is None: + method = func.__name__[2:] + else: + method = func_name + + if getattr(derived_class, method) is getattr(base_class, method): + if required: + raise NotImplementedError(f"{derived_class}'s {method} method is not implemented!") + return None + return func + + return inner diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 2aea9d3fd61d..3b323f4227a2 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -15,11 +15,15 @@ # specific language governing permissions and limitations # under the License. """MicroTVM module for bare-metal backends""" - from .build import autotvm_build_func from .build import AutoTvmModuleLoader from .build import get_standalone_crt_dir -from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError +from .build import get_microtvm_template_projects + +from .model_library_format import ( + export_model_library_format, + UnsupportedInModelLibraryFormatError, +) from .project import generate_project, GeneratedProject, TemplateProject from .session import ( create_local_graph_executor, diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 9e278081933c..795a61edcbb3 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -22,6 +22,7 @@ import os import pathlib import contextlib +import enum from typing import Union from .._ffi import libinfo @@ -34,10 +35,24 @@ STANDALONE_CRT_DIR = None +class MicroTVMTemplateProject(enum.Enum): + ZEPHYR = "zephyr" + ARDUINO = "arduino" + CRT = "crt" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + class CrtNotFoundError(Exception): """Raised when the standalone CRT dirtree cannot be found.""" +class MicroTVMTemplateProjectNotFoundError(Exception): + """Raised when the microTVM template project dirtree cannot be found.""" + + def get_standalone_crt_dir() -> str: """Find the standalone_crt directory. @@ -64,6 +79,37 @@ def get_standalone_crt_dir() -> str: return STANDALONE_CRT_DIR +def get_microtvm_template_projects(platform: str) -> str: + """Find microTVM template project directory for specific platform. + + Parameters + ---------- + platform : str + Platform type which should be defined in MicroTVMTemplateProject. + + Returns + ------- + str : + Path to template project directory for platform. + """ + if platform not in MicroTVMTemplateProject.list(): + raise ValueError(f"platform {platform} is not supported.") + + if platform == MicroTVMTemplateProject.CRT.value: + return os.path.join(get_standalone_crt_dir(), "template", "host") + + microtvm_template_projects = None + for path in libinfo.find_lib_path(): + template_path = os.path.join(os.path.dirname(path), "microtvm_template_projects") + if os.path.isdir(template_path): + microtvm_template_projects = template_path + break + else: + raise MicroTVMTemplateProjectNotFoundError() + + return os.path.join(microtvm_template_projects, platform) + + class AutoTvmModuleLoader: """MicroTVM AutoTVM Module Loader diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index d1a36ac79d64..a5e54aa816a3 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -184,7 +184,7 @@ def generate_project_from_mlf( mlf_path : pathlib.Path or str Path to the Model Library Format archive that will be used when creating - the new project. + the new project. The archive file will be copied to project_dir. options : dict Project API options given to the microTVM API server for the specified platform. diff --git a/python/tvm/micro/testing.py b/python/tvm/micro/testing.py index 124f66e021a3..81e29a92a86a 100644 --- a/python/tvm/micro/testing.py +++ b/python/tvm/micro/testing.py @@ -19,8 +19,16 @@ import pathlib import json +import logging +import tarfile +import time from typing import Union +from tvm.micro.project_api.server import IoTimeoutError + +# Timeout in seconds for AOT transport. +TIMEOUT_SEC = 10 + def check_tune_log(log_path: Union[pathlib.Path, str]): """Read the tuning log and check each result.""" @@ -31,3 +39,47 @@ def check_tune_log(log_path: Union[pathlib.Path, str]): if len(line) > 0: tune_result = json.loads(line) assert tune_result["result"][0][0] < 1000000000.0 + + +def aot_transport_init_wait(transport): + """Send init message to microTVM device until it receives wakeup sequence.""" + while True: + try: + aot_transport_find_message(transport, "wakeup", timeout_sec=TIMEOUT_SEC) + break + except IoTimeoutError: + transport.write(b"init%", timeout_sec=TIMEOUT_SEC) + + +def aot_transport_find_message(transport, expression: str, timeout_sec: int) -> str: + """Read transport message until it finds the expression.""" + timeout = timeout_sec + start_time = time.monotonic() + while True: + data = _read_line(transport, timeout) + logging.debug("new line: %s", data) + if expression in data: + return data + timeout = max(0, timeout_sec - (time.monotonic() - start_time)) + + +def _read_line(transport, timeout_sec: int) -> str: + data = bytearray() + while True: + new_data = transport.read(1, timeout_sec=timeout_sec) + logging.debug("read data: %s", new_data) + for item in new_data: + data.append(item) + if str(chr(item)) == "\n": + return data.decode(encoding="utf-8") + + +def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[pathlib.Path, str]) -> int: + """Extract an MLF archive file and read workspace size from metadata file.""" + + with tarfile.open(mlf_tar_path, "r:*") as tar_file: + tar_members = [ti.name for ti in tar_file.getmembers()] + assert "./metadata.json" in tar_members + with tar_file.extractfile("./metadata.json") as f: + metadata = json.load(f) + return metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"] diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index 4fc2b63748db..b6a402b0f30f 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -15,4 +15,6 @@ # specific language governing permissions and limitations # under the License. """Backend codegen modules for relay.""" -from . import compile_engine +from . import te_compiler +from .executor import Executor +from .runtime import Runtime diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index b970aec62c6f..d0d04cebaefe 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -16,7 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" -from typing import List +from typing import List, Type + import numpy as np # type: ignore import tvm # type: ignore @@ -26,6 +27,7 @@ from tvm.relay.dataflow_pattern import wildcard from tvm.relay.dataflow_pattern import is_op from tvm.relay.dataflow_pattern import rewrite +from tvm.relay.dataflow_pattern import CallPattern from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api @@ -121,7 +123,7 @@ def __call__(self, *args, **kwargs): pass -class EthosUConv2DRewriter(DFPatternCallback): +class Conv2DRewriter(DFPatternCallback): """Convert conv2d related composite functions into ethosu_conv2d operators""" def __init__(self): @@ -193,14 +195,14 @@ def callback( @ir.transform.module_pass(opt_level=1) -class LegalizeEthosUConv2D: - """This is the pass that wraps the EthosUConv2DRewriter""" +class LegalizeConv2D: + """This is the pass that wraps the Conv2DRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): - func = rewrite(EthosUConv2DRewriter(), func) + func = rewrite(Conv2DRewriter(), func) mod.update_func(global_var, func) return mod @@ -208,7 +210,7 @@ def __call__(self, *args, **kwargs): pass -class EthosuDepthwiseConv2DRewriter(DFPatternCallback): +class DepthwiseConv2DRewriter(DFPatternCallback): """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d operators""" @@ -286,14 +288,342 @@ def callback( @ir.transform.module_pass(opt_level=1) -class LegalizeEthosUDepthwiseConv2D: - """This is the pass that wraps the EthosUDepthwiseConv2DRewriter""" +class LegalizeDepthwiseConv2D: + """This is the pass that wraps the DepthwiseConv2DRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): - func = rewrite(EthosuDepthwiseConv2DRewriter(), func) + func = rewrite(DepthwiseConv2DRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class PoolingRewriter(DFPatternCallback): + """Convert ethosu.avgpool2d and ethosu.maxpool2d composite functions to + ethosu_pooling operators""" + + def __init__( + self, + params_class: Type, + pattern: CallPattern, + ): + super().__init__(require_type=True) + self.params_class = params_class + self.pattern = pattern + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[0] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + + activation_map = {"clip": "CLIP"} + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + + # Activations requiring LUT is currently not supported, so setting it to an empty list + lut = relay.const([], dtype="int8") + + return ethosu_ops.ethosu_pooling( + ifm=post.args[0], + lut=lut, + pooling_type=params.pooling_type, + ifm_scale=params.ifm.q_params.scale_f32, + ifm_zero_point=params.ifm.q_params.zero_point, + ofm_scale=params.ofm.q_params.scale_f32, + ofm_zero_point=params.ofm.q_params.zero_point, + pool_shape=params.pool_shape, + ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]], + strides=params.strides, + padding=params.padding, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + + +class MaxPoolingRewriter(PoolingRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MaxPool2DParams, + pattern=( + wildcard().has_attr({"Composite": ethosu_patterns.MaxPool2DParams.composite_name}) + )(wildcard()), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMaxPooling: + """This is the pass that wraps the MaxPoolingRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MaxPoolingRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class AvgPoolingRewriter(PoolingRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.AvgPool2DParams, + pattern=( + wildcard().has_attr({"Composite": ethosu_patterns.AvgPool2DParams.composite_name}) + )(wildcard()), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeAvgPooling: + """This is the pass that wraps the AvgPoolingRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(AvgPoolingRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class BinaryElementwiseRewriter(DFPatternCallback): + """Convert ethosu binary elementwise composite functions to + ethosu_binary_elementwise operators""" + + def __init__( + self, + params_class: Type, + pattern: CallPattern, + ): + super().__init__(require_type=True) + self.params_class = params_class + self.pattern = pattern + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0] + params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + + activation_map = {"clip": "CLIP"} + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + + # We don't yet support activation functions that need to get legalized to LUTs. + lut = relay.const([], dtype="int8") + + return ethosu_ops.ethosu_binary_elementwise( + ifm=params.ifm.tensor, + ifm2=params.ifm2.tensor, + lut=lut, + operator_type=params.operator_type, + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ifm2_scale=float(params.ifm2.q_params.scale_f32), + ifm2_zero_point=int(params.ifm2.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + ifm_channels=params.ifm.shape[3], + ifm2_channels=params.ifm2.shape[3], + reversed_operands=params.reversed_operands, + ofm_dtype=params.ofm.dtype, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + ifm_layout=str(params.ifm.layout), + ifm2_layout=str(params.ifm2.layout), + ofm_layout=str(params.ofm.layout), + ) + + +class AddRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.AddParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AddParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeAdd: + """This is the pass that wraps the AddRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(AddRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class SubRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.SubParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.SubParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeSub: + """This is the pass that wraps the SubRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(SubRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class MulRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MulParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MulParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMul: + """This is the pass that wraps the MulRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MulRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class MinRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MinParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MinParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMin: + """This is the pass that wraps the MinRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MinRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class MaxRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MaxParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MaxParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMax: + """This is the pass that wraps the MaxRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MaxRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class ShlRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.ShlParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.ShlParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeShl: + """This is the pass that wraps the ShlRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(ShlRewriter(), func) mod.update_func(global_var, func) return mod @@ -311,9 +641,20 @@ class LegalizeEthosU: def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: + """This is the method that replaces the operations with hardware/codegen supported + operations. + """ mod = LegalizeSplit()(mod) - mod = LegalizeEthosUConv2D()(mod) - mod = LegalizeEthosUDepthwiseConv2D()(mod) + mod = LegalizeConv2D()(mod) + mod = LegalizeDepthwiseConv2D()(mod) + mod = LegalizeMaxPooling()(mod) + mod = LegalizeAvgPooling()(mod) + mod = LegalizeAdd()(mod) + mod = LegalizeSub()(mod) + mod = LegalizeMul()(mod) + mod = LegalizeMin()(mod) + mod = LegalizeMax()(mod) + mod = LegalizeShl()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index 1063db6a04c5..05d405304589 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -18,3 +18,5 @@ from .convolution import ethosu_conv2d from .depthwise import ethosu_depthwise_conv2d +from .pooling import ethosu_pooling +from .binary_elementwise import ethosu_binary_elementwise diff --git a/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py new file mode 100644 index 000000000000..d4ae18b52974 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operators for binary elementwise operators for Arm(R) Ethos(TM)-U NPU""" +from typing import Optional +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import binary_elementwise_compute + + +def _extract_ethosu_binary_elementwise_params(attrs, args): + """Get the parameters necessary to construct a ethosu_binary_elementwise compute TE + from a ethosu_binary_elementwise Relay call.""" + ifm = args[0] + ifm2 = args[1] + lut = args[2] + operator_type = attrs.operator_type + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + ifm2_scale = attrs.ifm2_scale + ifm2_zero_point = attrs.ifm2_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + ifm_channels = attrs.ifm_channels + ifm2_channels = attrs.ifm2_channels + reversed_operands = attrs.reversed_operands + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + ifm_layout = attrs.ifm_layout + ifm2_layout = attrs.ifm2_layout + ofm_layout = attrs.ofm_layout + ofm_dtype = attrs.ofm_dtype + + return ( + ifm, + ifm2, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ifm2_scale, + ifm2_zero_point, + ofm_scale, + ofm_zero_point, + ifm_channels, + ifm2_channels, + reversed_operands, + activation, + clip_min, + clip_max, + ifm_layout, + ifm2_layout, + ofm_layout, + ofm_dtype, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMCompute") +def create_ethosu_binary_elementwise_compute(attrs, args, out_type): + """Create an ethosu_binary_elementwise compute op.""" + params = _extract_ethosu_binary_elementwise_params(attrs, args) + op = binary_elementwise_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMStrategy") +def binary_elementwise_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_binary_elementwise_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_binary_elementwise", + ) + return strategy + + +def ethosu_binary_elementwise( + ifm: tvm.relay.Expr, + ifm2: tvm.relay.Expr, + lut: tvm.relay.Expr, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ifm2_scale: float, + ifm2_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ifm_channels: int, + ifm2_channels: int, + reversed_operands: bool, + ofm_dtype: str, + activation: Optional[str] = "NONE", + clip_min: Optional[int] = 0, + clip_max: Optional[int] = 0, + ifm_layout: Optional[str] = "NHWC", + ifm2_layout: Optional[str] = "NHWC", + ofm_layout: Optional[str] = "NHWC", +) -> tvm.relay.Call: + """This is a quantized binary elementwise operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format + for the input data. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + ifm2 : tvm.relay.Expr + The Input Feature Map tensor 2 (IFM2). + lut : tvm.relay.Expr + The look-up table of values to use if activation = "LUT". + operator_type: str + The type of the binary elementwise operator. + "ADD" + "SUB" + "MUL" + "MIN" + "MAX" + "SHR" + "SHL" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ifm2_scale : float + The quantization scale for the Input Feature Map tensor 2. + ifm2_zero_point : int + The quantization zero point for the Input Feature Map tensor 2. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ifm_channels : int + The number of the Input Feature Map channels. + ifm2_channels : int + The number of the Input Feature Map 2 channels. + reversed_operands : bool + True if IFM2 is the first operand and IFM is the second operand. + ofm_dtype: str + The Output Feature Map tensor type. + MUL, ADD, SUB {IFM}->{OFM}: + {uint8, int8 int32} -> {uint8, int8, int32}, any pairing + MAX, MIN: + IFM and OFM must be of the same type, one of: + {int8, uint8} + SHR {IFM}->{OFM}: + {int32}->{int8, uint8, int32}, any pairing" + SHL: + {int32}->{int32} only + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + Available activations for activation type: + {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT" + {int32}: "NONE" + clip_min : int, optional + The minimum clipping value if activation = "CLIP". + clip_max : int, optional + The maximum clipping value if activation = "CLIP". + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm2_layout : str, optional + The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_binary_elementwise op. + """ + return _make.ethosu_binary_elementwise( + ifm, + ifm2, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ifm2_scale, + ifm2_zero_point, + ofm_scale, + ofm_zero_point, + ifm_channels, + ifm2_channels, + reversed_operands, + activation, + clip_min, + clip_max, + ifm_layout, + ifm2_layout, + ofm_layout, + ofm_dtype, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py index b159830ceaa9..970e366e5040 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -112,7 +112,7 @@ def ethosu_conv2d( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D convolution operation as supported by the + """This is a quantized 2D convolution operation as supported by the NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. @@ -132,7 +132,7 @@ def ethosu_conv2d( scale_bias : tvm.relay.Expr The packed per-channel weight scale and bias tensor. lut : tvm.relay.Expr - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -146,7 +146,7 @@ def ethosu_conv2d( kernel_shape : tuple of int The 2 dimensional kernel shape as (kernel_height, kernel_width). ofm_channels : int - The number of OFM channels. + The number of the Output Feature Map channels. strides : tuple of int, optional The 2 dimensional strides as (stride_height, stride_width). padding : tuple of int, optional diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py index abcddf90b97c..d8f2e8b3106c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument -"""Relay operator for depthwise convolution""" +"""Relay operator for depthwise convolution for Arm(R) Ethos(TM)-U NPU""" + from typing import Tuple import tvm @@ -112,8 +113,8 @@ def ethosu_depthwise_conv2d( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D depthwise convolution operation as supported - by the NPU. It accepts either NHWC or NHCWB16 format + """This is a quantized 2D depthwise convolution operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. Reference: https://developer.arm.com/documentation/102420/0200/ @@ -132,7 +133,7 @@ def ethosu_depthwise_conv2d( scale_bias : tvm.relay.Expr The packed per-channel weight scale and bias tensor. lut : tvm.relay.Expr - The look-up table values to use if activation = "LUT" + The look-up table of values to use if activation = "LUT" ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -146,7 +147,7 @@ def ethosu_depthwise_conv2d( kernel_shape : tuple of int The 2 dimensional kernel shape as (kernel_height, kernel_width). ofm_channels : int - The number of OFM channels. + The number of the Output Feature Map channels. strides : tuple of int, optional The 2 dimensional strides as (stride_height, stride_width). padding : tuple of int, optional diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py new file mode 100644 index 000000000000..cc363738c37f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operators for pooling for Arm(R) Ethos(TM)-U NPU""" +from typing import Tuple + +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import pooling_compute + + +def _extract_ethosu_pooling_params(attrs, args): + """Get the parameters necessary to construct a ethosu_pooling compute TE + from a ethosu_pooling Relay call.""" + ifm = args[0] + lut = args[1] + pooling_type = attrs.pooling_type + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + pool_shape = attrs.pool_shape + ofm_channels = attrs.ofm_channels + strides = attrs.strides + padding = attrs.padding + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + upscale = attrs.upscale + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + lut, + pooling_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + pool_shape, + ofm_channels, + strides, + padding, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.pooling", "FTVMCompute") +def create_ethosu_pooling_compute(attrs, args, out_type): + """Create an ethosu_pooling compute op.""" + params = _extract_ethosu_pooling_params(attrs, args) + op = pooling_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.pooling", "FTVMStrategy") +def pooling_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_pooling_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_pooling", + ) + return strategy + + +def ethosu_pooling( + ifm: tvm.relay.Expr, + lut: tvm.relay.Expr, + pooling_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + pool_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int] = (1, 1), + padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + activation: str = "NONE", + clip_min: int = 0, + clip_max: int = 0, + upscale: str = "NONE", + ifm_layout: str = "NHWC", + ofm_layout: str = "NHWC", +) -> tvm.relay.Call: + """This is a quantized 2D pooling operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format + for the input data. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + lut : tvm.relay.Expr + The look-up table of values to use if activation = "LUT". + pooling_type: str + The type of the pooling. "AVG" - average pool, "MAX" - max pool. + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + pool_shape : tuple of int + The 2 dimensional pool shape as (pool_shape_height, pool_shape_width). + ofm_channels : int + The number of the Output Feature Map channels + strides : tuple of int, optional + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple of int, optional + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP". + clip_max : int, optional + The maximum clipping value if activation = "CLIP". + upscale: str, optional + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_pooling op. + """ + return _make.ethosu_pooling( + ifm, + lut, + pooling_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + pool_shape, + ofm_channels, + strides, + padding, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index 5dcdd4dcf602..5c262362e4f4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -18,3 +18,5 @@ from .convolution import * from .depthwise import * +from .pooling import * +from .binary_elementwise import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py new file mode 100644 index 000000000000..84d4e1b41558 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for binary_elementwise""" +import operator +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def binary_elementwise_compute( + ifm: te.Tensor, + ifm2: te.Tensor, + lut: te.Tensor, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ifm2_scale: float, + ifm2_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ifm_channels: int, + ifm2_channels: int, + reversed_operands: bool, + activation: str, + clip_min: int, + clip_max: int, + ifm_layout: str, + ifm2_layout: str, + ofm_layout: str, + ofm_dtype: str, +) -> te.Tensor: + """A compute operator representing the capabilities of binary_elementwise for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + ifm2 : te.Tensor + The Input Feature Map tensor 2 (IFM2). + lut : te.Tensor + The look-up table values to use if activation = "LUT". + operator_type: str + The type of the binary elementwise operator. + "ADD" + "SUB" + "MUL" + "MIN" + "MAX" + "SHR" + "SHL" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ifm2_scale : float + The quantization scale for the Input Feature Map tensor 2. + ifm2_zero_point : int + The quantization zero point for the Input Feature Map tensor 1. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ifm_channels : int + The number of the Input Feature Map channels. + ifm2_channels : int + The number of the Input Feature Map 2 channels. + reversed_operands : bool + True if IFM2 is the first operand and IFM is the second operand. + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + Available activations for activation type: + {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT" + {int32}: "NONE" + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm2_layout : str, optional + The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_dtype: str + The Output Feature Map tensor type. + MUL, ADD, SUB {IFM}->{OFM}: + {uint8, int8 int32} -> {uint8, int8, int32}, any pairing + MAX, MIN: + IFM and OFM must be of the same type, one of: + {int8, uint8} + SHR {IFM}->{OFM}: + {int32}->{int8, uint8, int32}, any pairing" + SHL: + {int32}->{int32} only + + Returns + ------- + te.Tensor + The Output Feature Map tensor. + """ + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute( + ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0) + ) + dmaed_ifm2 = dma_ifm_compute( + ifm2, ifm2_layout, ifm2_zero_point, ifm2_scale, ifm2_channels, (0, 0, 0, 0) + ) + + # Binary elementwise compute operation + ofm_height = dmaed_ifm.shape[1] + ofm_width = dmaed_ifm.shape[2] + + binary_elementwise_attrs = { + "op": "ethosu_binary_elementwise", + "operator_type": operator_type, + "reversed_operands": reversed_operands, + "activation": activation, + "clip_min": clip_min, + "clip_max": clip_max, + } + + operators = { + "ADD": operator.add, + "SUB": operator.sub, + "MUL": operator.mul, + "MIN": te.min, + "MAX": te.max, + "SHR": operator.add, + "SHL": operator.add, + } + broadcast = [value == 1 for value in dmaed_ifm2.shape] + + if reversed_operands: + binary_elementwise = te.compute( + (1, ofm_height, ofm_width, ifm_channels), + lambda nn, hh, ww, cc: operators[operator_type]( + dmaed_ifm2( + 0 if broadcast[0] else nn, + 0 if broadcast[1] else hh, + 0 if broadcast[2] else ww, + 0 if broadcast[3] else cc, + ).astype(ifm.dtype), + dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype), + ).astype(ofm_dtype), + name="ethosu_binary_elementwise", + attrs=binary_elementwise_attrs, + ) + else: + binary_elementwise = te.compute( + (1, ofm_height, ofm_width, ifm_channels), + lambda nn, hh, ww, cc: operators[operator_type]( + dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype), + dmaed_ifm2( + 0 if broadcast[0] else nn, + 0 if broadcast[1] else hh, + 0 if broadcast[2] else ww, + 0 if broadcast[3] else cc, + ).astype(ifm.dtype), + ).astype(ofm_dtype), + name="ethosu_binary_elementwise", + attrs=binary_elementwise_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 26f7ea979219..1a7f96ace8eb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -53,7 +53,7 @@ def conv2d_compute( scale_bias : te.Tensor The packed per-channel weight scale and bias tensor. lut : te.Tensor - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index 35ae7f9a700a..6c139c958fa1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -53,7 +53,7 @@ def depthwise_conv2d_compute( scale_bias : te.Tensor The packed per-channel weight scale and bias tensor. lut : te.Tensor - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py new file mode 100644 index 000000000000..2f090f289da2 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for poolings""" +from typing import Tuple + +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def pooling_compute( + ifm: te.Tensor, + lut: te.Tensor, + pooling_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + pool_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int], + padding: Tuple[int, int, int, int], + activation: str, + clip_min: int, + clip_max: int, + upscale: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: + """A compute operator representing the capabilities of pooling for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + lut : te.Tensor + The look-up table of values to use if activation = "LUT". + pooling_type: str + The type of the pooling. "AVG" - average pool, "MAX" - max pool. + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + pool_shape : Tuple[int, int] + The 2 dimensional pool shape as (pool_shape_height, pool_shape_width). + ofm_channels : int + The number of the Output Feature Map channels + strides : Tuple[int, int] + The 2 dimensional strides as (stride_height, stride_width). + padding : Tuple[int, int, int, int] + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + upscale : str + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. + """ + stride_h, stride_w = strides + pool_shape_h, pool_shape_w = pool_shape + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) + + # Pooling compute operation + ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1 + ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1 + rh = te.reduce_axis((0, pool_shape_h), name="ry") + rw = te.reduce_axis((0, pool_shape_w), name="rx") + + pooling_attrs = { + "op": "ethosu_pooling", + "pooling_type": pooling_type, + "stride_h": stride_h, + "stride_w": stride_w, + "activation": activation, + "clip_min": clip_min, + "clip_max": clip_max, + "upscale": upscale, + } + + pooling = te.compute( + (1, ofm_height, ofm_width, ofm_channels), + lambda nn, hh, ww, cc: te.max( + dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype), + axis=[rh, rw], + ), + name="ethosu_pooling", + attrs=pooling_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py new file mode 100644 index 000000000000..1ea24edccb60 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the binary_elementwise operators in TIR.""" +from typing import Dict, Tuple +import tvm +from .utils import get_outer_loops, get_op_attrs +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialActivation, SerialBinaryElementwise + + +def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var: + """When the datatype of the ifm, ifm2 and ofm do not match, + casts are inserted in TE to handle the difference in these types. + Since TIR is not directly run on the NPU we can simply ignore + these, and allow the NPU to handle the difference in datatypes + itself. + + Parameters + ---------- + tir_load : tvm.tir.expr.Load + + Returns + ------- + tvm.tir.Var + """ + return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load + + +def get_binary_elementwise_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a binary_elementwise. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a binary elementwise loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialBinaryElementwise + The parameters needed to construct a binary elementwise operator. + output_pointer : tvm.tir.Var + The output pointer of the binary elementwise operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the binary elementwise output pointer. + """ + attrs, body = get_op_attrs(stmt) + reversed_operands = attrs["reversed_operands"] + + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + op = ignore_cast(inner.value) + input_pointer = ignore_cast(op.a).buffer_var + input_pointer1 = ignore_cast(op.b).buffer_var + + if reversed_operands: + input_pointer, input_pointer1 = input_pointer1, input_pointer + output_pointer = inner.buffer_var + # Get feature map info + serial_ifm, _ = get_ifm_params(input_pointer, producers) + serial_ifm2, _ = get_ifm_params(input_pointer1, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + SerialBinaryElementwise( + ifm=serial_ifm, + ifm2=serial_ifm2, + ofm=serial_ofm, + operator_type=attrs["operator_type"], + reversed_operands=reversed_operands, + activation=serial_activation, + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c792ade06643..b68a5ad14a6f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. The resulting TIR module will contain a single function - that consists of a sequence of tir.extern_calls to NPU + that consists of a sequence of tir.call_extern to NPU operations. Parameters @@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) + mod = tvm.tir.transform.Simplify()(mod) mod = tvm.tir.transform.LoopPartition()(mod) mod = RemoveZeroStores()(mod) mod = tvm.tir.transform.Simplify()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 761c8aad7bb1..a5678d1cc2d1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -22,6 +22,8 @@ from tvm.relay.backend.contrib.ethosu import vela_api from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params +from .pooling import get_pooling_params +from .binary_elementwise import get_binary_elementwise_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -54,6 +56,8 @@ def ReplaceOperators(): "ethosu_conv2d": get_conv2d_params, "ethosu_copy": get_copy_params, "ethosu_depthwise_conv2d": get_depthwise_conv2d_params, + "ethosu_pooling": get_pooling_params, + "ethosu_binary_elementwise": get_binary_elementwise_params, } pointer_to_producer = {} pointer_to_consumer = {} diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py new file mode 100644 index 000000000000..30f9bb3d981e --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the pooling operators in TIR.""" +from typing import Dict, Tuple +import tvm +from .utils import get_outer_loops, get_op_attrs +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialKernel, SerialActivation, SerialPooling + + +def get_pooling_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a pooling. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convolution loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialPooling + The parameters needed to construct a 2D convolution. + output_pointer : tvm.tir.Var + The output pointer of the convolution operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the convolution output pointer. + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + compute = rw.body.value.b + input_pointer = compute.buffer_var + output_pointer = rw.body.buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=1, + dilation_h=1, + ) + + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + SerialPooling( + ifm=serial_ifm, + ofm=serial_ofm, + pooling_type=attrs["pooling_type"], + pool_shape=serial_kernel, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index ff019c7783db..269238a157ef 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -261,3 +261,24 @@ def __init__( self.padding = padding self.activation = activation self.upscale = upscale + + +class SerialBinaryElementwise(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.binary_elementwise tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ifm2: SerialFeatureMap, + ofm: SerialFeatureMap, + operator_type: str, + reversed_operands: bool, + activation: SerialActivation, + ): + self.ifm = ifm + self.ifm2 = ifm2 + self.ofm = ofm + self.operator_type = operator_type + self.reversed_operands = reversed_operands + self.activation = activation diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index bcae01a10214..f82d7bb857a6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -85,10 +85,10 @@ def translate(tir_module, params): """ buffer_info = extract_buffer_info(tir_module, params) - extern_calls = extract_extern_calls(tir_module) + call_extern_list = extract_call_extern_list(tir_module) _npu_ops = list() - for extern_call in extern_calls: - _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) + for call_extern in call_extern_list: + _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) target_accel_config = vela_api.get_accelerator_config() cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) @@ -97,7 +97,7 @@ def translate(tir_module, params): return payload.hex(), hex_value, scratch_size -def extract_extern_calls(mod): +def extract_call_extern_list(mod): """This function will obtain all extern calls from a TIR module Parameters @@ -115,14 +115,14 @@ def extract_extern_calls(mod): assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] - extern_calls = list() + call_extern_list = list() - def populate_extern_calls(stmt): + def populate_call_extern_list(stmt): if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": - extern_calls.append(stmt) + call_extern_list.append(stmt) - stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) - return extern_calls + stmt_functor.post_order_visit(primfunc.body, populate_call_extern_list) + return call_extern_list def extract_buffer_info( @@ -213,7 +213,10 @@ def replace_npu_fm_with_address(npu_fm): buffer = npu_fm.tiles.addresses[0].buffer_var assert buffer in buffer_addresses.keys() address, buffer_type = buffer_addresses[buffer] - npu_fm.tiles.addresses[0] = address + index = npu_fm.tiles.addresses[0].index * ( + np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 + ) + npu_fm.tiles.addresses[0] = address + int(index) npu_fm.region = _REGION_MAP[buffer_type] return npu_fm @@ -295,18 +298,20 @@ def classify_io(buffer): return npu_ops, constant_tensor, scratch_size -def translate_ethosu_tir_extern_call(tir_extern_call): +def translate_ethosu_tir_call_extern(tir_call_extern): """This is a dispatcher function to dispatch correct translation call depending on the extern call's first argument""" - supported_extern_calls = { + supported_call_extern = { "ethosu_conv2d": translate_ethosu_conv2d, "ethosu_copy": translate_ethosu_copy, "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d, + "ethosu_pooling": translate_ethosu_pooling, + "ethosu_binary_elementwise": translate_ethosu_binary_elementwise, } - ext_call_type = tir_extern_call.args[0].value - assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" - npu_op = supported_extern_calls[ext_call_type](tir_extern_call) + ext_call_type = tir_call_extern.args[0].value + assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is not yet supported" + npu_op = supported_call_extern[ext_call_type](tir_call_extern) # Some conversions return additional outputs # if they are needed, the caller should use the function directly if isinstance(npu_op, tuple): @@ -314,20 +319,21 @@ def translate_ethosu_tir_extern_call(tir_extern_call): return npu_op -def translate_ethosu_copy(tir_extern_call): - """This function will translate a tir ethosu_copy extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_copy(tir_call_extern: tvm.tir.Call) -> vapi.NpuDmaOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + Parameters ---------- - tir_extern_call : tvm.tir.Call + tir_call_extern : tvm.tir.Call Returns ------- ethosu.vela.api.NpuDmaOperation The vela object containing the params of ethosu_copy """ - # We skip the first element as it is the extern_call function name - serial_object = spec.create_serial_object(spec.SerialCopy, tir_extern_call.args[1:]) + # We skip the first element as it is the call_extern function name + serial_object = spec.create_serial_object(spec.SerialCopy, tir_call_extern.args[1:]) return _create_npu_dma_op(serial_object) @@ -360,7 +366,7 @@ def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv Parameters ---------- tir_call_extern : tvm.tir.Call - This should be a TIR call_extern that has a agreed upon ordering + This should be a TIR call_extern that has agreed upon ordering for TIR Compiler. See Serial2DConvolution in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. @@ -370,7 +376,6 @@ def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv The vela object containing the params of ethosu_conv2d weights_zero_point : int The zero point of the weights - """ # We skip the first element as it is the call_extern function name serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_call_extern.args[1:]) @@ -417,25 +422,27 @@ def _create_npu_op_conv2d( return npu_conv2d_op, weights_zero_point -def translate_ethosu_depthwise_conv2d(tir_extern_call): - """This function will translate a tir extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_depthwise_conv2d( + tir_call_extern: tvm.tir.Call, +) -> Tuple[vapi.NpuConvDepthWiseOperation, int]: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. Parameters ---------- - tir_extern_call : tvm.tir.Call - This should be a tir external call that has an agreed upon ordering - for NPU TIR Compiler. See Serial2DDepthwise in + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See Serial2DDepthwise in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. Returns ------- - ethosu.vela.api.NpuDepthWiseOperation + ethosu.vela.api.NpuConvDepthWiseOperation The vela object containing the params of ethosu_depthwise_conv2d weights_zero_point : int The zero point of the weights """ - serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_extern_call.args[1:]) + serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_call_extern.args[1:]) return _create_npu_op_depthwise_conv2d(serial_object) @@ -479,6 +486,7 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N } layout = str(serial_feature_map.layout.value) data_type = str(serial_feature_map.data_type.value) + date_type_bytes = np.iinfo(np.dtype(data_type)).bits // 8 assert layout in layout_map.keys() assert data_type in datatype_map.keys() nfm = vapi.NpuFeatureMap() @@ -504,9 +512,9 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N ) nfm.layout = layout_map[layout] nfm.strides = vapi.NpuShape3D( - int(serial_feature_map.stride_h), - int(serial_feature_map.stride_w), - int(serial_feature_map.stride_c), + int(serial_feature_map.stride_h.value) * date_type_bytes, + int(serial_feature_map.stride_w.value) * date_type_bytes, + int(serial_feature_map.stride_c.value) * date_type_bytes, ) return nfm @@ -625,3 +633,115 @@ def _create_npu_dma_op(serial_copy): length=int(serial_copy.length.value), ) return vapi.NpuDmaOperation(src, dest) + + +def translate_ethosu_pooling(tir_call_extern: tvm.tir.Call) -> vapi.NpuPoolingOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + + Parameters + ---------- + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See SerialPooling in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. + + Returns + ------- + ethosu.vela.api.NpuPoolingOperation + The vela object containing the params of ethosu_pooling + """ + serial_object = spec.create_serial_object(spec.SerialPooling, tir_call_extern.args[1:]) + return _create_npu_op_pooling(serial_object) + + +def _create_npu_op_pooling(serial_pooling: spec.SerialPooling): + pooling_type = serial_pooling.pooling_type + if pooling_type == "AVG": + npu_pooling_op = vapi.NpuPoolingOp.AVERAGE + elif pooling_type == "MAX": + npu_pooling_op = vapi.NpuPoolingOp.MAX + + npu_pooling_op = vapi.NpuPoolingOperation(npu_pooling_op) + npu_pooling_op.ifm = _create_npu_feature_map(serial_pooling.ifm) + npu_pooling_op.ofm = _create_npu_feature_map(serial_pooling.ofm) + npu_pooling_op.kernel = _create_npu_kernel(serial_pooling.pool_shape) + npu_pooling_op.padding = _create_npu_padding(serial_pooling.padding) + + npu_pooling_op.activation = _create_npu_activation(serial_pooling.activation) + if ( + npu_pooling_op.activation + and npu_pooling_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_pooling_op) + + npu_pooling_op.upscale = _create_npu_resampling_mode(serial_pooling.upscale) + + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_pooling_op, target_accel_config) + npu_pooling_op.block_config = block_config + + return npu_pooling_op + + +def translate_ethosu_binary_elementwise( + tir_call_extern: tvm.tir.Call, +) -> vapi.NpuElementWiseOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + + Parameters + ---------- + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See SerialBinaryElementwise in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. + + Returns + ------- + ethosu.vela.api.NpuElementWiseOperation + The vela object containing the params of ethosu_binary_elementwise + """ + serial_object = spec.create_serial_object( + spec.SerialBinaryElementwise, tir_call_extern.args[1:] + ) + return _create_npu_op_binary_elementwise(serial_object) + + +def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBinaryElementwise): + operator_type = serial_binary_elementwise.operator_type + if operator_type == "ADD": + op = vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + op = vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + op = vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + op = vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + op = vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + op = vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + op = vapi.NpuElementWiseOp.SHL + + npu_binary_elementwise_op = vapi.NpuElementWiseOperation(op) + npu_binary_elementwise_op.ifm = _create_npu_feature_map(serial_binary_elementwise.ifm) + npu_binary_elementwise_op.ifm2 = _create_npu_feature_map(serial_binary_elementwise.ifm2) + npu_binary_elementwise_op.ofm = _create_npu_feature_map(serial_binary_elementwise.ofm) + npu_binary_elementwise_op.reversed_operands = serial_binary_elementwise.reversed_operands + + npu_binary_elementwise_op.activation = _create_npu_activation( + serial_binary_elementwise.activation + ) + if ( + npu_binary_elementwise_op.activation + and npu_binary_elementwise_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_binary_elementwise_op) + + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_binary_elementwise_op, target_accel_config) + npu_binary_elementwise_op.block_config = block_config + + return npu_binary_elementwise_op diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index ee47e4abd42b..8afb6eb9b9ee 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -75,6 +75,21 @@ class ClipArgs(Enum): A_MAX = 2 +class BinaryElementwiseArgs(Enum): + """This is a helper enums to access the correct index + of binary elementwise arguments + """ + + ifm = 0 + ifm2 = 1 + ifm_scale = 2 + ifm_zero_point = 3 + ifm2_scale = 4 + ifm2_zero_point = 5 + ofm_scale = 6 + ofm_zero_point = 7 + + def is_composite_func(func: relay.Function, name: str) -> bool: """ This method checks whether the call is to diff --git a/python/tvm/relay/backend/executor.py b/python/tvm/relay/backend/executor.py new file mode 100644 index 000000000000..b3af565fe69e --- /dev/null +++ b/python/tvm/relay/backend/executor.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=len-as-condition,no-else-return,invalid-name +"""Executor configuration""" + +import tvm +from tvm.runtime import Object + +from . import _backend + + +@tvm._ffi.register_object +class Executor(Object): + """Executor configuration""" + + def __init__(self, name, options=None) -> None: + if options is None: + options = {} + self.__init_handle_by_constructor__(_backend.CreateExecutor, name, options) + self._attrs = _backend.GetExecutorAttrs(self) + + def __contains__(self, name): + return name in self._attrs + + def __getitem__(self, name): + return self._attrs[name] + + @staticmethod + def list_executors(): + """Returns a list of possible executors""" + return list(_backend.ListExecutors()) + + @staticmethod + def list_executor_options(executor): + """Returns the dict of available option names and types""" + return dict(_backend.ListExecutorOptions(str(executor))) diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 701ca06a87e0..7b147b440f40 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -73,6 +73,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule): Attributes ---------- + ir_mod : :py:class:`~tvm.IRModule` + The IR module to build. target : tvm.Target The Target used to build this module. libmod : tvm.Module @@ -110,11 +112,13 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule): Attributes ---------- + ir_mod : :py:class:`~tvm.IRModule` + The IR module to build. + target : tvm.Target + The Target used to build this module. graph_json_str : the json graph to be deployed in json format output by graph compiler. The graph can contain operator(tvm_op) that points to the name of PackedFunc in the libmod. - target : tvm.Target - The Target used to build this module. libmod : tvm.Module The module of the corresponding function libmod_name: str diff --git a/python/tvm/relay/backend/runtime.py b/python/tvm/relay/backend/runtime.py new file mode 100644 index 000000000000..81779a245dde --- /dev/null +++ b/python/tvm/relay/backend/runtime.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=len-as-condition,no-else-return,invalid-name +"""Runtime configuration""" + +import tvm +from tvm.runtime import Object + +from . import _backend + + +@tvm._ffi.register_object +class Runtime(Object): + """Runtime configuration""" + + def __init__(self, name, options=None) -> None: + if options is None: + options = {} + self.__init_handle_by_constructor__(_backend.CreateRuntime, name, options) + self._attrs = _backend.GetRuntimeAttrs(self) + + def __contains__(self, name): + return name in self._attrs + + def __getitem__(self, name): + return self._attrs[name] + + @staticmethod + def list_runtimes(): + """Returns a list of possible runtimes""" + return list(_backend.ListRuntimes()) + + @staticmethod + def list_runtime_options(runtime): + """Returns the dict of available option names and types""" + return dict(_backend.ListRuntimeOptions(str(runtime))) diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/te_compiler.py similarity index 79% rename from python/tvm/relay/backend/compile_engine.py rename to python/tvm/relay/backend/te_compiler.py index e9129db7b200..db7504915887 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/te_compiler.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=len-as-condition,no-else-return,invalid-name -"""Backend code generation engine.""" +"""TE compiler engine (replacing legacy compile_engine).""" from __future__ import absolute_import import logging -import numpy as np import tvm from tvm import te, autotvm from tvm.ir.transform import PassContext @@ -31,7 +30,7 @@ from .. import ty as _ty from . import _backend -logger = logging.getLogger("compile_engine") +logger = logging.getLogger("te_compiler") autotvm_logger = logging.getLogger("autotvm") _first_warning = True @@ -47,7 +46,7 @@ def __init__(self, outputs, implement): @tvm._ffi.register_object("relay.CCacheKey") class CCacheKey(Object): - """Key in the CompileEngine. + """Key in the TE Compiler. Parameters ---------- @@ -64,7 +63,7 @@ def __init__(self, source_func, target): @tvm._ffi.register_object("relay.CCacheValue") class CCacheValue(Object): - """Value in the CompileEngine, including usage statistics.""" + """Value in the TE Compiler, including usage statistics.""" def _get_cache_key(source_func, target): @@ -79,24 +78,6 @@ def _get_cache_key(source_func, target): return source_func -def get_shape(shape): - """Convert the shape to correct dtype and vars.""" - ret = [] - for dim in shape: - if isinstance(dim, tvm.tir.IntImm): - if libinfo()["INDEX_DEFAULT_I64"] == "ON": - ret.append(dim) - else: - val = int(dim) - assert val <= np.iinfo(np.int32).max - ret.append(tvm.tir.IntImm("int32", val)) - elif isinstance(dim, tvm.tir.Any): - ret.append(te.var("any_dim", "int32")) - else: - ret.append(dim) - return ret - - def get_valid_implementations(op, attrs, inputs, out_type, target): """Get all valid implementations from the op strategy. @@ -275,6 +256,24 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) return best_plevel_impl, outputs[best_plevel_impl] +def get_shape(shape): + """Convert the shape to correct dtype and vars.""" + ret = [] + for dim in shape: + if isinstance(dim, tvm.tir.IntImm): + if libinfo()["INDEX_DEFAULT_I64"] == "ON": + ret.append(dim) + else: + val = int(dim) + assert val <= np.iinfo(np.int32).max + ret.append(tvm.tir.IntImm("int32", val)) + elif isinstance(dim, tvm.tir.Any): + ret.append(te.var("any_dim", "int32")) + else: + ret.append(dim) + return ret + + @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" @@ -322,12 +321,12 @@ def lower_call(call, inputs, target): return LoweredOutput(outputs, best_impl) -@tvm._ffi.register_object("relay.CompileEngine") -class CompileEngine(Object): - """CompileEngine to get lowered code.""" +@tvm._ffi.register_object("relay.TECompiler") +class TECompiler(Object): + """TECompiler to get lowered code.""" def __init__(self): - raise RuntimeError("Cannot construct a CompileEngine") + raise RuntimeError("Cannot construct a TECompiler") def lower(self, source_func, target=None, mod_name="default"): """Lower a source_func to a CachedFunc. @@ -349,7 +348,7 @@ def lower(self, source_func, target=None, mod_name="default"): try: mod_name = mangle_module_name(mod_name) key = _get_cache_key(source_func, target) - return _backend._CompileEngineLower(self, key, mod_name) + return _backend._TECompilerLower(self, key, mod_name) except Exception: import traceback @@ -360,10 +359,6 @@ def lower(self, source_func, target=None, mod_name="default"): msg += "--------------------------\n" raise RuntimeError(msg) - def lower_shape_func(self, source_func, target=None): - key = _get_cache_key(source_func, target) - return _backend._CompileEngineLowerShapeFunc(self, key) - def jit(self, source_func, target=None): """JIT a source_func to a tvm.runtime.PackedFunc. @@ -381,87 +376,30 @@ def jit(self, source_func, target=None): The result of jited function. """ key = _get_cache_key(source_func, target) - return _backend._CompileEngineJIT(self, key) + return _backend._TECompilerJIT(self, key) def clear(self): """clear the existing cached functions""" - _backend._CompileEngineClear(self) + _backend._TECompilerClear(self) def items(self): """List items in the cache. - Returns ------- item_list : List[Tuple[CCacheKey, CCacheValue]] The list of items. """ - res = _backend._CompileEngineListItems(self) - assert len(res) % 2 == 0 - return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - - def shape_func_items(self): - """List items in the shape_func_cache. - - Returns - ------- - item_list : List[Tuple[CCacheKey, CCacheValue]] - The list of shape_func_items. - """ - res = _backend._CompileEngineListShapeFuncItems(self) + res = _backend._TECompilerListItems(self) assert len(res) % 2 == 0 return [(res[2 * i], res[2 * i + 1]) for i in range(len(res) // 2)] - def get_current_ccache_key(self): - return _backend._CompileEngineGetCurrentCCacheKey(self) - - def dump(self): - """Return a string representation of engine dump. - - Returns - ------- - dump : str - The dumped string representation - """ - items = self.items() - res = "====================================\n" - res += "CompilerEngine dump, %d items cached\n" % len(items) - for k, v in items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - shape_func_items = self.shape_func_items() - res += "%d shape_func_items cached\n" % len(shape_func_items) - for k, v in shape_func_items: - res += "------------------------------------\n" - res += "target={}\n".format(k.target) - res += "use_count={}\n".format(v.use_count) - res += "func_name={}\n".format(v.cached_func.prim_fn_var.name_hint) - res += "----relay function----\n" - res += k.source_func.astext() + "\n" - res += "----tir function----- \n" - res += "inputs={}\n".format(v.cached_func.inputs) - res += "outputs={}\n".format(v.cached_func.outputs) - res += "function: \n" - res += v.cached_func.funcs.astext() + "\n" - res += "===================================\n" - return res - def get(): - """Get the global compile engine. + """Get the global TE Compiler. Returns ------- - engine : tvm.relay.backend.CompileEngine - The compile engine. + engine : tvm.relay.backend.TECompiler + The TE Compiler. """ - return _backend._CompileEngineGlobal() + return _backend._TECompilerGlobal() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 363ff893df8b..1dde27f172b1 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -275,14 +275,18 @@ def __init__(self, mod, device, target): self.mod = mod self.device = device self.target = target - self.executable = compile(mod, target) - self.vm = vm_rt.VirtualMachine(self.executable, device) + self.executable = None + self.vm = None def _make_executor(self, expr=None): - main = self.mod["main"] + if expr: + self.mod["main"] = expr + + self.executable = compile(self.mod, self.target) + self.vm = vm_rt.VirtualMachine(self.executable, self.device) def _vm_wrapper(*args, **kwargs): - args = self._convert_args(main, args, kwargs) + args = self._convert_args(self.mod["main"], args, kwargs) return self.vm.run(*args) return _vm_wrapper diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index b8273b0324c0..30327e580884 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -21,11 +21,12 @@ import numpy as np import tvm from tvm.ir import IRModule + +from ... import nd as _nd from .. import analysis from .. import expr as _expr from .. import function as _function from .. import op as _op -from ... import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape @@ -50,6 +51,7 @@ def __init__(self, init_layer_dict, predict_layer, exp_tab): "Deconvolution": self.convert_deconv, "Dropout": self.convert_dropout, "Eltwise": self.convert_eltwise, + "Embed": self.convert_embed, "Flatten": self.convert_flatten, "InnerProduct": self.convert_innerproduct, "Input": None, @@ -513,6 +515,9 @@ def convert_deconv(self, op): weight_shape = [-1, conv_params.num_output, kh, kw] weight_value = np.asarray(weight.data, np.float32) weight_value = np.reshape(weight_value, weight_shape) + + # weight shape is in relay's IOHW format rn, we need it to be OIHW + weight_value = np.transpose(weight_value, [1, 0, 2, 3]) else: raise Exception("No weight value of layer {} in caffemodel".format(op.name)) @@ -520,7 +525,6 @@ def convert_deconv(self, op): in_expr = self.exp_tab.get_expr(inputs[0]) out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) if bias: - bias_value = np.asarray(bias.data, np.float32) bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") out = _op.nn.bias_add(out, bias_expr) @@ -593,6 +597,46 @@ def convert_crop(self, op): out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis) return out + def convert_embed(self, op): + """Convert Embed layer""" + inputs = op.bottom + embed_param = op.embed_param + num_output = embed_param.num_output + input_dim = embed_param.input_dim + bias_term = embed_param.bias_term + weight_bias_blobs = self.init_layer_dict[op.name].blobs + weight, bias = None, None + if bias_term: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + assert weight and bias + else: + weight = weight_bias_blobs[0] + assert weight + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, [input_dim, num_output]) + weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") + in_expr = self.exp_tab.get_expr(inputs[0]) + input_shape = _infer_shape(in_expr) + input_count = 1 + for dim in input_shape: + input_count *= dim + + index = _op.cast(in_expr, "int32") + out = _op.take(weight_expr, index, axis=0) + + if bias_term: + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") + out = _op.reshape(out, [input_count, num_output]) + out = _op.add(out, bias_expr) + + out_shape = list(input_shape) + out_shape.append(num_output) + out = _op.reshape(out, out_shape) + + return out + def check_unsupported_ops(self): """Check unsupported Caffe ops in our converter.""" unsupported_ops_set = set() diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 825a586918f8..737830986d74 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -28,6 +28,7 @@ from .. import function as _function from .. import transform as _transform from .. import op as _op +from .. import ty as _ty from .. import analysis # pylint: disable=invalid-name @@ -577,14 +578,15 @@ def infer_value_simulated(input_val, params): return output_value -def try_infer_value(val, on_success=None, on_failure=None): +def try_infer_value(val, on_success=None, on_failure=None, parameters=None): """Try running infer_value on the input val, and if successful, return the inferred value or pass it to on_success callback if provided. Otherwise, run on_failure callback if it is provided, or return the input val as output. In each case, the second return value indicates whether infer_value has succeeded or not. """ try: - ret = infer_value(val, {}).numpy() + params = parameters if parameters is not None else {} + ret = infer_value(val, params).numpy() if on_success: return on_success(ret), True return ret, True @@ -594,6 +596,16 @@ def try_infer_value(val, on_success=None, on_failure=None): return val, False +def shape_of(x, dtype="int64"): + """Get shape of a tensor.""" + + ttype = infer_type(x).checked_type + if not _ty.is_dynamic(ttype): + shape = list(ttype.shape) + return _expr.const(shape, dtype) + return _op.shape_of(x, dtype) + + def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): return _expr.var(name_hint, type_annotation, shape, dtype) @@ -837,6 +849,69 @@ def lstm_cell( return outputs_list, hidden_state, cell_state +def autopad( + data, + strides, + kernel_shape, + dilations=(1, 1), + pad_type="constant", + deconv=False, + mode="SAME_UPPER", + pad_value=0.0, +): + """ + Perform autopadding with dynamic input shapes + """ + # get attributes as constants + strides = _op.const(np.array(strides), dtype="int64") + dilated_kernel_shape = _op.const( + np.array( + [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] + ), + dtype="int64", + ) + # get input shape + ndim = len(infer_shape(data)) + shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) + + # set up integer constants + zero = _op.const(0, dtype="int64") + one = _op.const(1, dtype="int64") + two = _op.const(2, dtype="int64") + + # Calculate total padding + mod = _op.mod(shape, strides) + + left = _op.maximum(dilated_kernel_shape - strides, zero) + right = _op.maximum(dilated_kernel_shape - mod, zero) + + total_pad = _op.where(_op.equal(mod, zero), left, right) + if deconv: + total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad + + # split total padding into before and after + pad_before = _op.floor_divide(total_pad, two) + pad_after = total_pad - pad_before + + # combine + if "LOWER" in mode: + pad = _op.concatenate( + [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 + ) + else: + pad = _op.concatenate( + [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 + ) + + # pad N and C with zeros + pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) + + if isinstance(pad_value, (float, int)): + pad_value = _op.const(pad_value) + + return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) + + def ensure_scalar_shape(x): """ Assume that `x` is a tensor with one element (regardless of tensor rank). diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 4a0336a056b3..901eee33575d 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, import-self, import-outside-toplevel """Keras frontend.""" +import dis import sys import numpy as np import tvm @@ -991,10 +992,110 @@ def _convert_repeat_vector(inexpr, keras_layer, _): out_shape = [-1, repeats] + input_shape[1:] out = _op.repeat(inexpr, repeats=repeats, axis=0) out = _op.reshape(out, out_shape) - return out +def _convert_l2_normalize(inexpr, keras_layer, etab): + l2_normalize_is_loaded = False + param_list = [] + for i in dis.get_instructions(keras_layer.function): + if i.opname in ["LOAD_GLOBAL", "LOAD_DEREF"]: + continue + if i.opname in ["LOAD_ATTR", "LOAD_METHOD"]: + if i.argval == "l2_normalize": + assert not l2_normalize_is_loaded, "l2_normalize was already LOADED" + l2_normalize_is_loaded = True + elif i.opname in ["LOAD_CONST", "LOAD_FAST"] and l2_normalize_is_loaded: + param_list.append(i.argval) + elif i.opname == "BUILD_LIST": + sz = i.argval + assert len(param_list) >= sz + new_list = param_list[-sz:] + param_list = param_list[:-sz] + param_list.append(new_list) + elif i.opname in ["CALL_FUNCTION_KW", "CALL_METHOD"]: + break + + axis = None + is_param_list_parsed = False + if l2_normalize_is_loaded and len(param_list) > 0: + # last param_list item is tuple of strings means that + # lambda uses named parameters when calling l2_normalize + if ( + isinstance(param_list[-1], tuple) + and len(param_list[-1]) > 0 + and isinstance(param_list[-1][0], str) + ): + param_names = param_list[-1] + if len(param_names) == 1 and param_names[0] == "x": + # lambda v: K.l2_normalize(x=v) + axis = None + is_param_list_parsed = True + elif len(param_names) == 1 and param_names[0] == "axis" and len(param_list) == 3: + # lambda x: K.l2_normalize(x, axis=(2,3)) + axis = param_list[1] + is_param_list_parsed = True + elif len(param_names) == 2 and len(param_list) == 3: + # lambda x: K.l2_normalize(x=x, axis=(2,3)) + # lambda x: K.l2_normalize(axis=(2,3), x=x) + axis = param_list[param_names.index("axis")] + is_param_list_parsed = True + else: + # lambda x: K.l2_normalize(x) + if len(param_list) == 1: + axis = None + is_param_list_parsed = True + # lambda x: K.l2_normalize(x, (2,3)) + elif len(param_list) == 2: + axis = param_list[1] + is_param_list_parsed = True + + def is_int_or_tuple_of_ints(v): + if isinstance(v, list) and len(v) > 0: + for i in v: + if not isinstance(i, int): + return False + return True + if isinstance(v, tuple) and len(v) > 0: + return isinstance(v[0], int) + return isinstance(v, int) + + assert is_param_list_parsed and ( + axis is None or is_int_or_tuple_of_ints(axis) + ), "Can not parse l2_normalize lambda function found in Lambda layer" + if isinstance(axis, int): + axis = [axis] + + if etab.data_layout == "NCHW": + dims = len(keras_layer.input_shape) + + def fix_axis_for_nchw(axis): + if axis == 0: + return 0 + if axis in [(dims - 1), -1]: + return 1 + return axis + 1 + + axis = [fix_axis_for_nchw(x) for x in axis] + return _op.nn.l2_normalize(inexpr, eps=1e-12, axis=axis) + + +def _convert_lambda(inexpr, keras_layer, etab): + fcode = keras_layer.function.__code__ + # Convert l2_normalize + if ( + fcode.co_name == "" + and len(fcode.co_names) > 0 + and fcode.co_names[-1] == "l2_normalize" + ): + return _convert_l2_normalize(inexpr, keras_layer, etab) + raise tvm.error.OpNotImplemented( + "Function {} used in Lambda layer is not supported in frontend Keras.".format( + fcode.co_names + ) + ) + + def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument """Layers that can be skipped because they are train time only.""" return inexpr @@ -1059,6 +1160,7 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument "Permute": _convert_permute, "Embedding": _convert_embedding, "RepeatVector": _convert_repeat_vector, + "Lambda": _convert_lambda, "InputLayer": _default_skip, "Dropout": _default_skip, "AlphaDropout": _default_skip, diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5c112c7dfce0..5813f6305ace 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,6 +38,7 @@ from .. import ty as _ty from .. import vision as _vision from .common import ( + autopad, AttrCvt, Renamer, ensure_scalar_shape, @@ -51,6 +52,7 @@ infer_value, lstm_cell, new_var, + shape_of, try_resolve_var_to_const, unbind, ) @@ -315,7 +317,6 @@ def _run_calculation(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], [1] * ndim, - ndim, pad_value=pad_val, mode=attr["auto_pad"], ) @@ -411,69 +412,6 @@ def _impl_v1(cls, inputs, attr, params): return AttrCvt(op_name="instance_norm")(inputs, attr, params) -def autopad( - data, - strides, - kernel_shape, - dilations, - ndim, - pad_type="constant", - deconv=False, - mode="SAME_UPPER", - pad_value=0.0, -): - """ - Perform autopadding with dynamic input shapes - """ - # get attributes as constants - strides = _op.const(np.array(strides), dtype="int64") - dilated_kernel_shape = _op.const( - np.array( - [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)] - ), - dtype="int64", - ) - # get input shape - shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim]) - - # set up integer constants - zero = _op.const(0, dtype="int64") - one = _op.const(1, dtype="int64") - two = _op.const(2, dtype="int64") - - # Calculate total padding - mod = _op.mod(shape, strides) - - left = _op.maximum(dilated_kernel_shape - strides, zero) - right = _op.maximum(dilated_kernel_shape - mod, zero) - - total_pad = _op.where(_op.equal(mod, zero), left, right) - if deconv: - total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - total_pad - - # split total padding into before and after - pad_before = _op.floor_divide(total_pad, two) - pad_after = total_pad - pad_before - - # combine - if "LOWER" in mode: - pad = _op.concatenate( - [_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1 - ) - else: - pad = _op.concatenate( - [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1 - ) - - # pad N and C with zeros - pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0) - - if isinstance(pad_value, (float, int)): - pad_value = _op.const(pad_value) - - return _op.nn.pad(data, fold_constant(pad), pad_value, pad_type) - - class Conv(OnnxOpConverter): """Operator converter for Conv.""" @@ -501,7 +439,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -582,7 +519,6 @@ def _impl_v1(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, deconv=True, mode=attr["auto_pad"], ) @@ -974,7 +910,6 @@ def _impl_v1(cls, inputs, attr, params): attr["strides"], attr["kernel_shape"], [1] * ndim, - ndim, mode=attr["auto_pad"], ) elif attr["auto_pad"] == "VALID": @@ -1410,14 +1345,6 @@ def _impl_v9(cls, inputs, attr, params): return out -def shape_of(x, dtype="int64"): - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(shape, dtype) - return _op.shape_of(x, dtype) - - class Shape(OnnxOpConverter): """Operator converter for Shape.""" @@ -1534,9 +1461,8 @@ class Split(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): splits = attr.get("split", None) - if splits is not None: + if splits is not None and len(splits) > 1: indices = [] - attr["indices_or_sections"] = [] index = 0 for i in splits[:-1]: index += i @@ -3440,7 +3366,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=x_zero_point.data, mode=attr["auto_pad"], ) @@ -3810,7 +3735,6 @@ def _impl_v10(cls, inputs, attr, params): attr.get("strides", [1] * (ndim - 2)), attr["kernel_shape"], attr.get("dilations", [1] * (ndim - 2)), - ndim, pad_value=data_zp, mode=attr["auto_pad"], ) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index c32449546f77..967238552b24 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -18,6 +18,7 @@ # pylint: disable=import-outside-toplevel """Paddle: PArallel Distributed Deep LEarning.""" +import warnings import numpy as np import tvm @@ -31,11 +32,13 @@ from .. import ty as _ty from .. import op as _op from .common import ( + autopad, fold_constant, get_relay_op, infer_shape, infer_type, infer_value, + shape_of, try_infer_value, new_var, ) @@ -43,20 +46,6 @@ __all__ = ["from_paddle"] -def _get_pad_size(in_size, dilated_kernel_size, stride_size): - """Calculate the paddings size for Conv/Pool in SAME padding mode.""" - - if stride_size == 1 or in_size % stride_size == 0: - pad = max(dilated_kernel_size - stride_size, 0) - else: - pad = max(dilated_kernel_size - (in_size % stride_size), 0) - - pad_before = pad // 2 - pad_after = pad - pad_before - - return [pad_before, pad_after] - - def _dtype_shape_promotion(inputs): """Promote data type and shape for list of tensors.""" @@ -78,16 +67,6 @@ def _dtype_shape_promotion(inputs): return inputs -def shape_of(x, dtype="int32"): - """Get shape of a tensor.""" - - ttype = infer_type(x).checked_type - if not _ty.is_dynamic(ttype): - shape = list(ttype.shape) - return _expr.const(np.array(shape), dtype) - return _op.shape_of(x, dtype) - - def _convert_dtype_value(val): """Converts a Paddle type id to a string.""" @@ -136,6 +115,32 @@ def convert_binary_logical_op(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_addmm(g, op, block): + """Operator converter for addmm.""" + + input_x = g.get_node(op.input("Input")[0]) + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + + alpha = op.attr("Alpha") + beta = op.attr("Beta") + dtype = block.var(op.output("Out")[0]).dtype + dtype = _convert_dtype_value(dtype) + + if not isinstance(alpha, _expr.Expr) and alpha != 1: + alpha = _expr.const(alpha, dtype) + x *= alpha + + if not isinstance(beta, _expr.Expr) and beta != 1: + beta = _expr.const(beta, dtype) + input_x *= beta + + transposed_y = _op.transpose(y, axes=[1, 0]) + dense_out = _op.nn.dense(x, transposed_y) + out = dense_out + input_x + g.add_node(op.output("Out")[0], out) + + def convert_arg_max_min(g, op, block): """Operator converter for arg_max and arg_min.""" @@ -213,6 +218,26 @@ def convert_batch_norm(g, op, block): g.add_node(op.output("Y")[0], out[0]) +def convert_bmm(g, op, block): + """Operator converter for bmm.""" + + x = g.get_node(op.input("X")[0]) + y = g.get_node(op.input("Y")[0]) + y = _op.transpose(y, [0, 2, 1]) + out = _op.nn.batch_matmul(x, y) + g.add_node(op.output("Out")[0], out) + + +def convert_brelu(g, op, block): + """Operator converter for brelu.""" + + x = g.get_node(op.input("X")[0]) + t_max = op.attr("t_max") + t_min = op.attr("t_min") + out = _op.tensor.clip(x, t_min, t_max) + g.add_node(op.output("Out")[0], out) + + def convert_cast(g, op, block): """Operator converter for cast.""" @@ -248,24 +273,16 @@ def convert_conv2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - if strides[0] == 1 and strides[1] == 1: - pad_h = _get_pad_size(0, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(0, (k_w - 1) * dilations[1] + 1, strides[1]) - else: - input_shape = shape_of(input_x) - h_w = _op.strided_slice(input_shape, [2], [4]) - try: - in_h, in_w = infer_value(h_w, g.get_params()).numpy().tolist() - except Exception as e: - msg = "Dynamic shape is not supported in SAME padding algorithm while stride!=1" - raise tvm.error.OpAttributeInvalid(msg) from e - pad_h = _get_pad_size(in_h, (k_h - 1) * dilations[0] + 1, strides[0]) - pad_w = _get_pad_size(in_w, (k_w - 1) * dilations[1] + 1, strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + # Handle history issue of PaddlePaddle + # while padding_algorithm == "SAME" + # dilations will be set to [1, 1] + dilations = [1, 1] + input_x = autopad(input_x, strides, [k_h, k_w], dilations) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Conv is not "valid."' @@ -365,10 +382,12 @@ def convert_expand(g, op, block): x = g.get_node(op.input("X")[0]) if op.input("Shape"): sizes = g.get_node(op.input("Shape")[0]) - sizes = try_infer_value(sizes, g.get_params())[0] else: sizes = op.attr("shape") + if isinstance(sizes, _expr.Expr): + sizes = try_infer_value(sizes, parameters=g.get_params())[0] + if isinstance(sizes, np.ndarray): sizes = sizes.tolist() @@ -430,10 +449,11 @@ def convert_fill_constant(g, op, block): value = _expr.const(value).astype(dtype) if "ValueTensor" in op.input_names and op.input("ValueTensor"): shape = g.get_node(op.input("ValueTensor")[0]) - shape = try_infer_value(shape, g.get_params())[0] if "ShapeTensor" in op.input_names and op.input("ShapeTensor"): shape = g.get_node(op.input("ShapeTensor")[0]) - shape = try_infer_value(shape, g.get_params())[0] + + if isinstance(shape, _expr.Expr): + shape = try_infer_value(shape, parameters=g.get_params())[0] if isinstance(shape, np.ndarray): shape = shape.tolist() @@ -442,6 +462,56 @@ def convert_fill_constant(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_flatten(g, op, block): + """Operator converter for flatten.""" + + x = g.get_node(op.input("X")[0]) + input_shape = list(infer_shape(x)) + + start = op.attr("start_axis") + end = op.attr("stop_axis") + ndim = len(input_shape) + if end < 0: + end += ndim + new_shape = [0] * start + + new_shape.append(-1) + squeeze_axes = [] + for i in range(start + 1, end + 1): + new_shape.append(1) + squeeze_axes.append(i) + for _ in range(end + 1, ndim): + new_shape.append(0) + out = _op.reshape(x, new_shape) + if squeeze_axes: + out = _op.squeeze(out, axis=squeeze_axes) + + g.add_node(op.output("Out")[0], out) + + +def convert_gather(g, op, block): + """Operator converter for gather.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + axis = op.attr("axis") + out = _op.take(x, index, axis) + g.add_node(op.output("Out")[0], out) + + +def convert_gather_nd(g, op, block): + """Operator converter for gather_nd.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + shape = infer_shape(index) + perm = list(range(0, len(shape) - 1)) + perm.insert(0, len(shape) - 1) + index = _op.transpose(index, axes=perm) + out = _op.gather_nd(x, index, 0, shape[-1]) + g.add_node(op.output("Out")[0], out) + + def convert_gelu(g, op, block): """Operator converter for gelu.""" @@ -453,6 +523,39 @@ def convert_gelu(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_group_norm(g, op, block): + """Operator converter for group_norm.""" + + x = g.get_node(op.input("X")[0]) + num_groups = op.attr("groups") + epsilon = op.attr("epsilon") + gamma = g.get_node(op.input("Scale")[0]) + beta = g.get_node(op.input("Bias")[0]) + out = _op.nn.group_norm( + x, + gamma=gamma, + beta=beta, + num_groups=num_groups, + axis=1, + epsilon=epsilon, + center=True, + scale=True, + ) + g.add_node(op.output("Y")[0], out) + + +def convert_hard_shrink(g, op, block): + """Operator converter for hard_shrink.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + threshold = op.attr("threshold") + threshold = _op.const(threshold, dtype) + out = _op.logical_or(x < _op.const(-1.0, dtype) * threshold, x > threshold) + out = _op.cast(out, dtype) * x + g.add_node(op.output("Out")[0], out) + + def convert_hard_sigmoid(g, op, block): """Operator converter for hard_sigmoid.""" @@ -479,6 +582,99 @@ def convert_hard_swish(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_interpolate(g, op, block): + """Operator converter for interpolate.""" + + def get_interpolate_mode(op): + """Get parameters for interpolation methods.""" + + interp_method = op.attr("interp_method") + align_corners = op.attr("align_corners") + align_mode = op.attr("align_mode") + + rounding_method = "" + if interp_method == "nearest": + interp_method = "nearest_neighbor" + coordinate_transformation_mode = "asymmetric" + rounding_method = "floor" + elif interp_method == "bilinear": + interp_method = "linear" + if not align_corners and align_mode == 0: + coordinate_transformation_mode = "half_pixel" + else: + if align_corners: + coordinate_transformation_mode = "align_corners" + else: + coordinate_transformation_mode = "asymmetric" + elif interp_method == "bicubic": + interp_method = "cubic" + if align_corners: + coordinate_transformation_mode = "align_corners" + else: + coordinate_transformation_mode = "half_pixel" + else: + msg = "interp_method {} is not supported for PaddlePaddle's interpolate" + raise tvm.error.OpAttributeInvalid(msg.format(interp_method)) + return rounding_method, interp_method, coordinate_transformation_mode + + layout = op.attr("data_layout") + out_h = op.attr("out_h") + out_w = op.attr("out_w") + + x = g.get_node(op.input("X")[0]) + x_shape = infer_shape(x) + assert len(x_shape) == 4, "Only 4D input tensor is supported for PaddlePaddle's interpolate" + input_out_size = op.input("OutSize") + input_size_tensor = op.input("SizeTensor") + input_scale = op.input("Scale") + rounding_method, interp_method, coordinate_transformation_mode = get_interpolate_mode(op) + + if input_out_size: + # if out_size is a tensor + out_size = g.get_node(input_out_size[0]) + out_size, infered = try_infer_value(out_size, parameters=g.get_params()) + if infered: + out_size = out_size.tolist() + elif input_size_tensor: + # if out_size is a list of tensor + out_size = list() + for name in input_size_tensor: + size = g.get_node(name) + if len(infer_shape(size)) == 0: + shape = _op.reshape(shape, [-1]) + out_size.append(size) + out_size = _op.concatenate(out_size, axis=0) + out_size, infered = try_infer_value(out_size, parameters=g.get_params()) + if infered: + out_size = out_size.tolist() + elif input_scale: + # if out_size is not defined, but scale is defined + input_scale = g.get_node(input_scale[0]) + input_shape = shape_of(x).astype("float32") + if layout.startswith("NC"): + out_size = _op.strided_slice(input_shape, begin=[2], end=[4]) * input_scale + else: + out_size = _op.strided_slice(input_shape, begin=[1], end=[3]) * input_scale + out_size = out_size.astype("int32") + out_size, infered = try_infer_value(out_size, parameters=g.get_params()) + if infered: + out_size = out_size.tolist() + else: + # if out_size is a constant value + out_size = [out_h, out_w] + + out = _op.image.resize2d( + x, + size=out_size, + layout=layout, + method=interp_method, + coordinate_transformation_mode=coordinate_transformation_mode, + rounding_method=rounding_method, + cubic_alpha=-0.75, + ) + g.add_node(op.output("Out")[0], out) + + def convert_layer_norm(g, op, block): """Operator converter for layer_norm.""" @@ -519,6 +715,15 @@ def convert_leaky_relu(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_logical_not(g, op, block): + """Operator converter for logical_not op.""" + + ipt0 = g.get_node(op.input("X")[0]) + op_func = get_relay_op(op.type) + out = op_func(ipt0) + g.add_node(op.output("Out")[0], out) + + def convert_lookup_table(g, op, block): """Operator converter for lookup_table_v2.""" @@ -559,9 +764,9 @@ def convert_matmul(g, op, block): # This implemention almost keeps same with ONNX # Need to check input shape as batch matmul must be supported. - a_shape = shape_of(inputs[0]) + a_shape = shape_of(inputs[0], dtype="int32") a_rank = infer_shape(a_shape)[0] - b_shape = shape_of(inputs[1]) + b_shape = shape_of(inputs[1], dtype="int32") b_rank = infer_shape(b_shape)[0] # When performing a batch matmul, we need to properly handle N-dim shapes. if a_rank > 2 or b_rank > 2: @@ -648,8 +853,8 @@ def convert_mul(g, op, block): y = g.get_node(op.input("Y")[0]) x_num_col_dims = op.attr("x_num_col_dims") y_num_col_dims = op.attr("y_num_col_dims") - x_shape = shape_of(x) - y_shape = shape_of(y) + x_shape = shape_of(x, dtype="int32") + y_shape = shape_of(y, dtype="int32") x_dim = infer_shape(x_shape)[0] y_dim = infer_shape(y_shape)[0] if x_num_col_dims < 0: @@ -686,6 +891,39 @@ def convert_mul(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_padding(g, op, block): + """Operator converter for padding.""" + + input_x = g.get_node(op.input("X")[0]) + input_padding = op.input("Paddings") + if input_padding: + padding = g.get_node(input_padding[0]) + padding = infer_value(padding, g.get_params()).numpy().tolist() + else: + padding = op.attr("paddings") + padding = op.attr("paddings") + value = op.attr("value") + data_format = op.attr("data_format") + mode = op.attr("mode") + assert mode != "circular", "Don't support mod='circular' for PaddlePaddle's padding" + if mode == "replicate": + mode = "edge" + + pad_len = len(padding) + new_paddings = [0] * (pad_len + 4) + for i in range(0, pad_len, 2): + index = -1 - i + if data_format[:2] != "NC": + index = -3 - i + new_paddings[index] = padding[i + 1] + new_paddings[index - 1] = padding[i] + + new_paddings = [new_paddings[i : i + 2] for i in range(0, len(new_paddings), 2)] + + out = _op.nn.pad(input_x, new_paddings, pad_value=value, pad_mode=mode) + g.add_node(op.output("Out")[0], out) + + def convert_pool2d(g, op, block): """Operator converter for pool2d.""" @@ -696,17 +934,19 @@ def convert_pool2d(g, op, block): paddings = op.attr("paddings") padding_algorithm = op.attr("padding_algorithm") pooling_type = op.attr("pooling_type") + if global_pooling: adaptive = True ksize = [1, 1] input_x = g.get_node(op.input("X")[0]) - in_h, in_w = infer_shape(input_x)[2:] + _, _, in_h, in_w = infer_shape(input_x) op_map = { "avg": "avg_pool2d", "max": "max_pool2d", } + strides = op.attr("strides") if isinstance(strides, int): strides = [strides, strides] @@ -718,27 +958,101 @@ def convert_pool2d(g, op, block): if padding_algorithm == "VALID": paddings = [0, 0] elif padding_algorithm == "SAME": - pad_h = _get_pad_size(in_h, ksize[0], strides[0]) - pad_w = _get_pad_size(in_w, ksize[1], strides[1]) - paddings = [pad_h[0], pad_w[0], pad_h[1], pad_w[1]] + input_x = autopad(input_x, strides, ksize) + paddings = [0, 0] elif padding_algorithm == "EXPLICIT": if len(paddings) == 2: paddings = [paddings[0], paddings[1], paddings[0], paddings[1]] - if len(paddings) == 4: + elif len(paddings) == 4: paddings = [paddings[0], paddings[2], paddings[1], paddings[3]] else: msg = 'Value {} in attribute "padding" of operator Pool2d is not "valid."' raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm)) + # handle with special case + # while kernel size less than input size + # shrink kernel size to input size + if not isinstance(in_h, _op.Expr) and in_h < ksize[0]: + ksize[0] = in_h + if not isinstance(in_w, _op.Expr) and in_w < ksize[1]: + ksize[1] = in_w + if not adaptive: - out = getattr(_op.nn, op_map[pooling_type])( - input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode - ) + if pooling_type == "avg": + exclusive = op.attr("exclusive") + out = _op.nn.avg_pool2d( + input_x, + pool_size=ksize, + strides=strides, + padding=paddings, + ceil_mode=ceil_mode, + count_include_pad=not exclusive, + ) + else: + out = getattr(_op.nn, op_map[pooling_type])( + input_x, pool_size=ksize, strides=strides, padding=paddings, ceil_mode=ceil_mode + ) else: out = getattr(_op.nn, "adaptive_" + op_map[pooling_type])(input_x, output_size=ksize) g.add_node(op.output("Out")[0], out) +def convert_pow(g, op, block): + """Operator converter for pow.""" + + x = g.get_node(op.input("X")[0]) + dtype = block.var(op.output("Out")[0]).dtype + dtype = _convert_dtype_value(dtype) + factor = op.attr("factor") + factor = _expr.const(factor, dtype=dtype) + out = _op.power(x, factor) + g.add_node(op.output("Out")[0], out) + + +def convert_reciprocal(g, op, block): + """Operator converter for reciprocal.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = _expr.const(1.0, dtype) / x + g.add_node(op.output("Out")[0], out) + + +def convert_reduce(g, op, block): + """Operator converter for series of reduce operators.""" + + op_map = { + "reduce_all": "all", + "reduce_any": "any", + "reduce_max": "max", + "reduce_min": "min", + "reduce_prod": "prod", + "reduce_sum": "sum", + "reduce_mean": "mean", + } + op_name = op_map[op.type] + input_x = g.get_node(op.input("X")[0]) + axis = op.attr("dim") + if op.attr("reduce_all"): + axis = None + keepdims = op.attr("keep_dim") + out = get_relay_op(op_name)(input_x, axis=axis, keepdims=keepdims) + if not axis and not keepdims: + # use `expand_dims` to solve the following situation + # for TVM, the shape of `out` will be (, ) + # for Paddle, the shape of `out` will be [1] + out = _op.expand_dims(out, axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_relu6(g, op, block): + """Operator converter for relu6.""" + + x = g.get_node(op.input("X")[0]) + out = _op.clip(x, 0.0, 6.0) + g.add_node(op.output("Out")[0], out) + + def convert_reshape(g, op, block): """Operator converter for reshape.""" @@ -748,18 +1062,16 @@ def convert_reshape(g, op, block): if input_shape: new_shape = g.get_node(input_shape[0]) elif input_shape_tensor: - tmp_shape = [] + new_shape = [] for shape_name in input_shape_tensor: shape = g.get_node(shape_name) if len(infer_shape(shape)) == 0: shape = _op.reshape(shape, [-1]) - if isinstance(shape, _expr.Constant): - tmp_shape.append(shape) - elif isinstance(shape, _expr.Expr): - tmp_shape.append(shape) - else: - tmp_shape.append(_expr.const(np.array(shape).astype("int64"))) - new_shape = _op.concatenate(tmp_shape, axis=0) + new_shape.append(shape) + new_shape = _op.concatenate(new_shape, axis=0) + new_shape, infered = try_infer_value(new_shape, parameters=g.get_params()) + if infered: + new_shape = new_shape.tolist() else: new_shape = op.attr("shape") out = _op.reshape(data, new_shape) @@ -792,11 +1104,70 @@ def convert_scale(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_scatter(g, op, block): + """Operator converter for scatter.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Ids")[0]) + updates = g.get_node(op.input("Updates")[0]) + overwrite = op.attr("overwrite") + + shape = infer_shape(updates) + ndims = len(shape) + index = _op.expand_dims(index, axis=-1, num_newaxis=ndims - 1) + index = _op.transform.broadcast_to(index, shape) + + if overwrite: + out = _op.scatter(x, index, updates, axis=0) + else: + out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0) + out += _op.scatter(x, index, _op.zeros_like(updates), axis=0) + g.add_node(op.output("Out")[0], out) + + +def convert_scatter_nd_add(g, op, block): + """Operator converter for scatter_nd_add.""" + + x = g.get_node(op.input("X")[0]) + index = g.get_node(op.input("Index")[0]) + updates = g.get_node(op.input("Updates")[0]) + indices_dim = len(infer_shape(index)) + axes = list(range(indices_dim)) + index = _op.transpose(index, axes[-1:] + axes[:-1]) + out = _op.scatter_nd(x, index, updates, mode="add") + g.add_node(op.output("Out")[0], out) + + +def convert_selu(g, op, block): + """Operator converter for selu.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + alpha = _op.const(op.attr("alpha"), dtype) + scale = _op.const(op.attr("scale"), dtype) + out = ( + _expr.const(-1.0, dtype=dtype) + * alpha + * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(x)) + ) + out = scale * (out + _op.nn.relu(x)) + g.add_node(op.output("Out")[0], out) + + def convert_shape(g, op, block): """Operator converter for shape.""" x = g.get_node(op.input("Input")[0]) - out = shape_of(x) + out = shape_of(x, dtype="int32") + g.add_node(op.output("Out")[0], out) + + +def convert_size(g, op, block): + """Operator converter for size.""" + + input_x = g.get_node(op.input("Input")[0]) + out = _op.ndarray_size(input_x, dtype="int64") + out = _op.expand_dims(out, axis=0) g.add_node(op.output("Out")[0], out) @@ -854,6 +1225,64 @@ def convert_softmax(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_softplus(g, op, block): + """Operator converter for softplus.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + beta = op.attr("beta") + beta = _expr.const(beta, dtype=dtype) + out = _op.log(_op.exp(x * beta) + _expr.const(1.0, dtype=dtype)) / beta + g.add_node(op.output("Out")[0], out) + + +def convert_softsign(g, op, block): + """Operator converter for softsign.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = x / (_op.const(1.0, dtype) + _op.abs(x)) + g.add_node(op.output("Out")[0], out) + + +def convert_square(g, op, block): + """Operator converter for square.""" + + x = g.get_node(op.input("X")[0]) + dtype = block.var(op.output("Out")[0]).dtype + dtype = _convert_dtype_value(dtype) + out = _op.power(x, _expr.const(2, dtype)) + g.add_node(op.output("Out")[0], out) + + +def convert_squeeze(g, op, block): + """Operator converter for squeeze2.""" + + x = g.get_node(op.input("X")[0]) + axes = op.attr("axes") + if not axes: + axes = None + x = _op.squeeze(x, axis=axes) + g.add_node(op.output("Out")[0], x) + + +def convert_swish(g, op, block): + """Operator converter for swish.""" + + x = g.get_node(op.input("X")[0]) + dtype = infer_type(x).checked_type.dtype + out = x / (_op.const(1.0, dtype) + _op.exp(_op.const(-1.0, dtype) * x)) + g.add_node(op.output("Out")[0], out) + + +def convert_transpose(g, op, block): + """Operator converter for transpose.""" + + perm = op.attr("axis") + out = _op.transpose(g.get_node(op.input("X")[0]), axes=perm) + g.add_node(op.output("Out")[0], out) + + def convert_unsqueeze(g, op, block): """Operator converter for unsqueeze.""" @@ -865,31 +1294,59 @@ def convert_unsqueeze(g, op, block): _convert_map = { + "abs": convert_unary_op, + "acos": convert_unary_op, + "addmm": convert_addmm, "arg_max": convert_arg_max_min, "arg_min": convert_arg_max_min, "argsort": convert_argsort, + "asin": convert_unary_op, "assign": convert_assign, "assign_value": convert_assign_value, + "atan": convert_unary_op, "batch_norm": convert_batch_norm, + "bicubic_interp_v2": convert_interpolate, + "bilinear_interp_v2": convert_interpolate, + "bmm": convert_bmm, + "brelu": convert_brelu, "cast": convert_cast, + "ceil": convert_unary_op, "concat": convert_concat, "conv2d": convert_conv2d, + "cos": convert_unary_op, + "cosh": convert_unary_op, "cumsum": convert_cumsum, "depthwise_conv2d": convert_conv2d, "dot": convert_dot, "dropout": convert_dropout, "elementwise_add": convert_elementwise_op, "elementwise_div": convert_elementwise_op, + "elementwise_floordiv": convert_elementwise_op, + "elementwise_max": convert_elementwise_op, + "elementwise_min": convert_elementwise_op, + "elementwise_mod": convert_elementwise_op, "elementwise_mul": convert_elementwise_op, + "elementwise_pow": convert_elementwise_op, + "elementwise_prod": convert_elementwise_op, "elementwise_sub": convert_elementwise_op, "equal": convert_elementwise_op, + "erf": convert_unary_op, "exp": convert_unary_op, "expand_v2": convert_expand, "expand_as_v2": convert_expand_as, "feed": convert_feed, "fill_any_like": convert_fill_any_like, "fill_constant": convert_fill_constant, + "flatten_contiguous_range": convert_flatten, + "floor": convert_unary_op, + "floor_mod": convert_elementwise_op, + "gather": convert_gather, + "gather_nd": convert_gather_nd, "gelu": convert_gelu, + "greater_equal": convert_elementwise_op, + "greater_than": convert_elementwise_op, + "group_norm": convert_group_norm, + "hard_shrink": convert_hard_shrink, "hard_sigmoid": convert_hard_sigmoid, "hard_swish": convert_hard_swish, "isfinite_v2": convert_unary_op, @@ -897,21 +1354,60 @@ def convert_unsqueeze(g, op, block): "isnan_v2": convert_unary_op, "layer_norm": convert_layer_norm, "leaky_relu": convert_leaky_relu, + "less_equal": convert_elementwise_op, + "less_than": convert_elementwise_op, + "log": convert_unary_op, + "log2": convert_unary_op, + "log10": convert_unary_op, "logical_and": convert_binary_logical_op, + "logical_not": convert_logical_not, "logical_or": convert_binary_logical_op, "logical_xor": convert_binary_logical_op, "lookup_table_v2": convert_lookup_table, "matmul": convert_matmul, "matmul_v2": convert_matmul, "mul": convert_mul, + "nearest_interp_v2": convert_interpolate, + "not_equal": convert_elementwise_op, + "pad1d": convert_padding, + "pad2d": convert_padding, + "pad3d": convert_padding, "pool2d": convert_pool2d, + "pow": convert_pow, "relu": convert_unary_op, + "relu6": convert_relu6, "reshape2": convert_reshape, + "round": convert_unary_op, + "reciprocal": convert_reciprocal, + "reduce_all": convert_reduce, + "reduce_any": convert_reduce, + "reduce_max": convert_reduce, + "reduce_min": convert_reduce, + "reduce_prod": convert_reduce, + "reduce_sum": convert_reduce, + "reduce_mean": convert_reduce, + "rsqrt": convert_unary_op, "scale": convert_scale, + "scatter": convert_scatter, + "scatter_nd_add": convert_scatter_nd_add, + "selu": convert_selu, "shape": convert_shape, + "sigmoid": convert_unary_op, + "sign": convert_unary_op, + "sin": convert_unary_op, + "sinh": convert_unary_op, + "size": convert_size, "slice": convert_slice, "softmax": convert_softmax, + "softplus": convert_softplus, + "softsign": convert_softsign, + "sqrt": convert_unary_op, + "square": convert_square, + "squeeze2": convert_squeeze, + "swish": convert_swish, + "tan": convert_unary_op, "tanh": convert_unary_op, + "transpose2": convert_transpose, "unsqueeze2": convert_unsqueeze, } @@ -1062,7 +1558,6 @@ def from_translated_layer(self, layer, shape_dict): def from_paddle(program_or_layer, shape_dict=None, scope=None): """Convert a PaddlePaddle model into an equivalent Relay Function. - PaddlePaddle Program/TranslatedLayer represent the computation graph of PaddlePaddle model, and PaddlePaddle scope stores all the weights of PaddlePaddle model. @@ -1087,6 +1582,10 @@ def from_paddle(program_or_layer, shape_dict=None, scope=None): import paddle + # disable system signal capturing in paddle framework + # the signal capturing may cause conflict while running autotvm with paddle frontend + paddle.disable_signal_handler() + g = GraphProto() if isinstance(program_or_layer, paddle.jit.TranslatedLayer): # model is loaded by `paddle.jit.load` diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b3aecb589352..a17a10e7b398 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -849,35 +849,23 @@ def hard_swish(self, inputs, input_types): data = inputs[0] return data * self.hard_sigmoid(inputs, input_types) - def adaptive_avg_pool_2d(self, inputs, input_types): + def adaptive_avg_pool(self, op, inputs, input_types): data = inputs[0] output_size = inputs[1] def func(x): - return _op.nn.adaptive_avg_pool2d(x, output_size=output_size) + return op(x, output_size=output_size) if self.is_quantized_tensor(data): return qnn_torch.apply_with_upcast(data, func) return func(data) - def adaptive_max_pool_2d(self, inputs, input_types): + def adaptive_max_pool(self, op, inputs, input_types): data = inputs[0] output_size = inputs[1] - # returns dummy indices too - return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None - - def adaptive_max_pool_3d(self, inputs, input_types): - data = inputs[0] - output_size = inputs[1] - # returns dummy indices too - return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None - - def adaptive_avg_pool_3d(self, inputs, input_types): - data = inputs[0] - output_size = inputs[1] - return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) + return op(data, output_size=output_size), None @staticmethod def convert_const_list(data): @@ -2797,6 +2785,39 @@ def searchsorted(self, inputs, input_types): def bucketize(self, inputs, input_types): return self.searchsorted_common(inputs[1], inputs[0], inputs[2], inputs[3]) + def roll(self, inputs, input_types): + def slide_axes(inp, shape, ax): + axes = list(range(len(shape))) + axes = axes[:ax] + [-1] + axes[ax:-1] + return _op.transpose(inp, axes) + + x = inputs[0] + shifts = inputs[1] + dims = inputs[2] + shape = self.infer_shape(x) + start = _expr.const(0, "int64") + step = _expr.const(1, "int64") + + out = x + for i, dim in enumerate(dims): + roll_dim = _expr.const(shape[dim], "int64") + indices_1d = _op.mod( + _op.transform.arange(start, roll_dim, step, "int64") + - _expr.const(shifts[i], "int64") + + roll_dim, + roll_dim, + ) + # First fill in the last axis with roll indices, and then do transpose to + # bring the roll indices into the desired axis. + indices = slide_axes( + _op.tile(indices_1d, shape[:dim] + shape[dim + 1 :] + (1,)), + shape, + dim, + ) + out = _op.gather(out, dim, indices) + + return out + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2854,9 +2875,26 @@ def create_convert_map(self): "aten::gelu": self.gelu, "aten::selu": self.selu, "aten::silu": self.silu, + "aten::silu_": self.silu, "aten::log_sigmoid": self.log_sigmoid, - "aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d, - "aten::adaptive_max_pool2d": self.adaptive_max_pool_2d, + "aten::adaptive_avg_pool1d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool1d + ), + "aten::adaptive_avg_pool2d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool2d + ), + "aten::adaptive_avg_pool3d": functools.partial( + self.adaptive_avg_pool, _op.nn.adaptive_avg_pool3d + ), + "aten::adaptive_max_pool1d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool1d + ), + "aten::adaptive_max_pool2d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool2d + ), + "aten::adaptive_max_pool3d": functools.partial( + self.adaptive_max_pool, _op.nn.adaptive_max_pool3d + ), "aten::max_pool2d": self.maxpool_2d, "aten::max_pool2d_with_indices": self.maxpool_2d_with_indices, "aten::max_pool1d": self.maxpool_1d, @@ -2942,6 +2980,7 @@ def create_convert_map(self): "aten::rsqrt": self.make_unary("rsqrt"), "aten::ceil": self.make_unary("ceil"), "aten::floor": self.make_unary("floor"), + "aten::floor_": self.make_unary("floor"), "aten::round": self.make_unary("round"), "aten::isfinite": self.make_unary("isfinite"), "aten::isinf": self.make_unary("isinf"), @@ -2967,8 +3006,6 @@ def create_convert_map(self): "aten::bitwise_xor": self.bitwise_xor, "aten::Bool": self.Bool, "aten::Float": self.Float, - "aten::adaptive_avg_pool3d": self.adaptive_avg_pool_3d, - "aten::adaptive_max_pool3d": self.adaptive_max_pool_3d, "aten::rsub": self.rsub, "aten::embedding": self.embedding, "aten::one_hot": self.one_hot, @@ -3024,6 +3061,7 @@ def create_convert_map(self): "aten::any": functools.partial(self.all_any_common, _op.any), "aten::searchsorted": self.searchsorted, "aten::bucketize": self.bucketize, + "aten::roll": self.roll, } def update_convert_map(self, custom_map): diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index a8213d4b1c49..26ea4f4dbc2a 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -461,8 +461,11 @@ def _impl(inputs, attr, params, mod): raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"])) if "kernel_layout" not in attr: - if opname in ["conv", "conv_transpose"]: + if opname == "conv": attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW" + elif opname == "conv_transpose": + # conv_transpose in TVM has weights be IOHW for NCHW + attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW" else: attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW" diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5d3681ebe132..05b3041320c9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1136,8 +1136,6 @@ def _convert_unary_elemwise(self, relay_op, op): def convert_abs(self, op): """Convert TFLite ABS""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFlite quantized ABS operator is not supported yet.") return self._convert_unary_elemwise(_op.abs, op) def convert_ceil(self, op): @@ -1194,8 +1192,6 @@ def convert_cos(self, op): def convert_sqrt(self, op): """Convert TFLite SQRT""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFlite quantized SQRT operator is not supported yet.") return self._convert_unary_elemwise(_op.sqrt, op) def convert_rsqrt(self, op): @@ -1204,8 +1200,6 @@ def convert_rsqrt(self, op): def convert_neg(self, op): """Convert TFLite NEG""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented("TFlite quantized NEG operator is not supported yet.") return self._convert_unary_elemwise(_op.negative, op) def convert_elu(self, op): @@ -2635,7 +2629,7 @@ def convert_mirror_pad(self, op): # paddings pad_list = self.get_tensor_value(input_tensors[1]) # convert list of lists to tuple of tuples - paddings = tuple(tuple(l) for l in pad_list) + paddings = tuple(tuple(l.astype(np.int32)) for l in pad_list) assert op.BuiltinOptionsType() == BuiltinOptions.MirrorPadOptions op_options = op.BuiltinOptions() diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 19162a108395..dd1a65288955 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -84,3 +84,28 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret + + +@script +def _searchsorted_shape(sorted_sequence_shape, values_shape): + out_shape = output_tensor((values_shape.shape[0],), "int64") + if sorted_sequence_shape.shape[0] > 1: + assert ( + sorted_sequence_shape.shape[0] == values_shape.shape[0] + ), "Ranks of `sorted_sequence` and values must be the same if `sorted_sequence` is not 1-D." + for i in range(values_shape.shape[0]): + if sorted_sequence_shape.shape[0] > 1 and i < values_shape.shape[0] - 1: + assert ( + sorted_sequence_shape[i] == values_shape[i] + ), "`sorted_sequence and `values` do not have the same shape along outer axes." + + out_shape[i] = values_shape[i] + return out_shape + + +@_reg.register_shape_func("searchsorted", False) +def searchsorted_shape_func(attrs, inputs, _): + """ + Shape func for searchsorted operator. + """ + return [_searchsorted_shape(inputs[0], inputs[1])] diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 30c2db0ddf0b..1dd6da6c2747 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -24,3 +24,4 @@ from .coreml import * from .ethosn import * from .tensorrt import * +from .cutlass import * diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index cf0e9156e65f..824343e0066b 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument """Arm(R) CMSIS-NN supported operators for Cortex-M.""" import tvm.ir +from tvm.target import Target from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name @@ -25,7 +26,7 @@ def enabled(): - return bool(tvm.get_global_func("relay.ext.cmsisnn", True)) + return "cmsis-nn" in Target.list_kinds() def partition_for_cmsisnn(mod, params=None, **opts): @@ -51,7 +52,7 @@ def partition_for_cmsisnn(mod, params=None, **opts): [ transform.InferType(), transform.MergeComposite(pattern_table()), - transform.AnnotateTarget("cmsisnn"), + transform.AnnotateTarget("cmsis-nn"), transform.MergeCompilerRegions(), transform.PartitionGraph(), ] @@ -60,9 +61,9 @@ def partition_for_cmsisnn(mod, params=None, **opts): return seq(mod) -@register_pattern_table("cmsisnn") +@register_pattern_table("cmsis-nn") def pattern_table(): - """Get the cmsisnn compiler pattern table.""" + """Get the CMSIS-NN compiler pattern table.""" def softmax_pattern(): pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) @@ -104,14 +105,14 @@ def check_quantized_binary_op(extract): ) return [ - ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax), + ("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax), ( - "cmsisnn.quantized_mul", + "cmsis-nn.quantized_mul", binary_op_pattern("mul"), check_quantized_binary_op, ), ( - "cmsisnn.quantized_add", + "cmsis-nn.quantized_add", binary_op_pattern("add"), check_quantized_binary_op, ), diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py new file mode 100644 index 000000000000..8ed371844a1c --- /dev/null +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Patterns supported CUTLASS.""" +from tvm.relay import transform +from ...dataflow_pattern import wildcard, is_op, is_constant + + +def make_gelu_pattern(bias_out, out_dtype="float16"): + mul = is_op("multiply")(bias_out, is_constant() | wildcard()) + if out_dtype == "float16": + erf = is_op("cast")(is_op("erf")(is_op("cast")(mul))) + else: + erf = is_op("erf")(mul) + mul_half = is_op("multiply")(erf, is_constant() | wildcard()) + add = is_op("add")(mul_half, is_constant() | wildcard()) + return is_op("multiply")(add, bias_out) + + +def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"): + """Create a pattern for dense op followed by activations.""" + data = wildcard() + weight = wildcard() + bias = wildcard() + gemm = is_op("nn.dense")(data, weight) + if with_bias: + add_or_bias_add = is_op("add") | is_op("nn.bias_add") + gemm_out = add_or_bias_add(gemm, bias) + else: + gemm_out = gemm + + if with_act is None: + return gemm_out + if isinstance(with_act, str) and with_act == "relu": + return is_op("nn.relu")(gemm_out) + + assert isinstance(with_act, str) and with_act == "gelu" + return make_gelu_pattern(gemm_out, out_dtype) + + +def make_batch_matmul_pattern(): + return is_op("nn.batch_matmul")(wildcard(), wildcard()) + + +def partition_for_cutlass(mod): + """Partition the input module into CUTLASS-supported subgraphs.""" + dense_pat = ("cutlass.dense", make_gemm_pattern(False, None)) + dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None)) + dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu")) + dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu")) + dense_bias_gelu_fp32_pat = ( + "cutlass.dense_bias_gelu_fp32", + make_gemm_pattern(True, "gelu", out_dtype="float32"), + ) + cutlass_patterns = [ + dense_bias_gelu_fp16_pat, + dense_bias_gelu_fp32_pat, + dense_bias_relu_pat, + dense_bias_pat, + dense_pat, + ("cutlass.batch_matmul", make_batch_matmul_pattern()), + ] + mod = transform.MergeComposite(cutlass_patterns)(mod) + mod = transform.AnnotateTarget(["cutlass"])(mod) + mod = transform.PartitionGraph()(mod) + return mod diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py index 39ecec7049b3..412ae713bae1 100644 --- a/python/tvm/relay/op/contrib/ethosn.py +++ b/python/tvm/relay/op/contrib/ethosn.py @@ -46,7 +46,7 @@ def ethosn_available(): return Available.SW_AND_HW if hw else Available.SW_ONLY -def partition_for_ethosn(mod, params=None, **opts): +def partition_for_ethosn77(mod, params=None, **opts): """Partition the graph greedily offloading supported operators to Arm Ethos-N NPU. @@ -61,6 +61,49 @@ def partition_for_ethosn(mod, params=None, **opts): ------- ret : annotated and partitioned module. """ + if opts: + tops = opts.get("tops", None) + ple_ratio = opts.get("ple_ratio", None) + sram_size = opts.get("sram_size", None) + if tops or ple_ratio or sram_size: + raise ValueError( + "Setting tops, ple_ratio or sram_size has no effect when targeting Ethos(TM)-N77" + ) + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.MergeComposite(pattern_table()), + transform.AnnotateTarget("ethos-n"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + + return seq(mod) + + +def partition_for_ethosn78(mod, params=None, **opts): + """Partition the graph greedily offloading supported + operators to Arm Ethos(TM)-N NPU. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + ret : annotated and partitioned module. + """ + if not opts or opts.get("variant", "").lower() != "ethos-n78": + raise ValueError("When targeting Ethos(TM)-N78, -variant=Ethos-N78 should be set.") + if params: mod["main"] = bind_params_by_name(mod["main"], params) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index ca417942840d..25538cae9dbc 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -23,7 +23,7 @@ import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant # type: ignore +from tvm.relay.expr import Constant, Call # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore @@ -40,6 +40,7 @@ from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs from tvm.relay.backend.contrib.ethosu.util import get_dim_value except ImportError: vapi = None @@ -99,9 +100,8 @@ def check_strides(strides: List[int]) -> bool: return True -def check_valid_dtypes(tensor_params: List[TensorParams]) -> bool: +def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List[type]) -> bool: """This function checks whether dtypes are supported by the NPU""" - supported_dtypes = (np.uint8, np.int8) for tep in tensor_params: # Check for dtypes if np.dtype(tep.dtype) not in supported_dtypes: @@ -170,6 +170,16 @@ def check_padding(padding: List[int], bounds: List[int]): return not (top > topb or left > leftb or bottom > bottomb or right > rightb) +def check_pool_shape(pool_shape: tvm.ir.container.Array) -> bool: + if len(pool_shape) != 2: + return False + if pool_shape[1] > 256: + return False + if pool_shape[0] * pool_shape[1] > 256 * 256: + return False + return True + + class QnnConv2DParams: """ This class will parse a Call to a ethosu.qnn_conv2d composite function @@ -238,7 +248,7 @@ def is_valid(self) -> bool: This function checks whether QnnConv2D has compatible attributes with the NPU """ tensor_params = [self.weights, self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if not check_weights(self.weights, self.dilation): return False @@ -277,7 +287,7 @@ def is_valid(self): Checks whether QnnDepthwiseConv2D + activation function has compatible attributes with HW """ tensor_params = [self.weights, self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if not check_weights(self.weights, self.dilation): return False @@ -331,6 +341,433 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return clip_or_req +class MaxPool2DParams: + """ + This class will parse a call to a ethosu.maxpool2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.maxpool2d" + # The hardware only supports padding upto the numbers as follows + padding_bounds = [127, 127, 128, 128] + + def __init__(self, func_body: Call): + clip = None + if str(func_body.op) == "clip": + clip = func_body + pool_op = clip.args[0] + else: + pool_op = func_body + + attrs = pool_op.attrs + self.ifm = TensorParams(pool_op.args[0], attrs.layout) + self.ofm = TensorParams(pool_op, attrs.layout) + self.pool_shape = attrs.pool_size + self.strides = attrs.strides + self.padding = attrs.padding + self.activation = clip + self.pooling_type = "MAX" + + def is_valid(self): + """ + This function checks whether MaxPool2D has compatible attributes with the NPU + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if not check_pool_shape(self.pool_shape): + return False + return True + + +def qnn_maxpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for nn.max_pool2d with optional fused RELU activation. + """ + pattern = is_op("nn.max_pool2d")(wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class AvgPool2DParams: + """ + This class will parse a call to a ethosu.avgpool2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.avgpool2d" + # The hardware only supports padding upto the numbers as follows + padding_bounds = [127, 127, 128, 128] + + def __init__(self, func_body: Call): + clip = None + if str(func_body.op) == "clip": + clip = func_body + cast2 = clip.args[0] + else: + cast2 = func_body + + avgpool = cast2.args[0] + cast1 = avgpool.args[0] + + attrs = avgpool.attrs + self.ifm = TensorParams(cast1.args[0], attrs.layout) + self.ofm = TensorParams(cast2, attrs.layout) + self.pool_shape = attrs.pool_size + self.strides = attrs.strides + self.padding = attrs.padding + self.activation = clip + self.pooling_type = "AVG" + + def is_valid(self): + """ + This function checks whether AvgPool2D has compatible attributes with the NPU + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if not check_pool_shape(self.pool_shape): + return False + return True + + +def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for nn.avg_pool2d with optional fused RELU activation. + """ + pattern = is_op("cast")(wildcard()) + pattern = is_op("nn.avg_pool2d")(pattern) + pattern = is_op("cast")(pattern) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class BinaryElementwiseParams: + """ + This class will parse a call to a ethosu.binary_elementwise composite function + and extract the parameter information. + """ + + def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool): + clip = None + if str(func_body.op) == "clip": + clip = func_body + binary_op = clip.args[0] + else: + binary_op = func_body + + layout = "NHWC" + + if has_quantization_parameters: + self.ifm = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm.value], + layout, + binary_op.args[BinaryElementwiseArgs.ifm_scale.value], + binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value], + ) + self.ifm2 = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm2.value], + layout, + binary_op.args[BinaryElementwiseArgs.ifm2_scale.value], + binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value], + ) + self.ofm = TensorParams( + binary_op, + layout, + binary_op.args[BinaryElementwiseArgs.ofm_scale.value], + binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value], + ) + else: + self.ifm = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm.value], + layout, + ) + self.ifm2 = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm2.value], + layout, + ) + self.ofm = TensorParams( + binary_op, + layout, + ) + self.activation = clip + self.operator_type = operator_type + + def can_broadcast(x, y): + for i in range(1, 4): + if x.shape[i] == y.shape[i] or y.shape[i] == 1: + continue + return False + return True + + if can_broadcast(self.ifm, self.ifm2): + self.reversed_operands = False + self.valid_broadcast = True + elif can_broadcast(self.ifm2, self.ifm): + self.reversed_operands = True + self.ifm, self.ifm2 = self.ifm2, self.ifm + self.valid_broadcast = True + else: + self.valid_broadcast = False + + def is_valid(self): + """ + This function checks whether BinaryElementwise has compatible attributes with the NPU + """ + if np.dtype(self.ofm) == np.int32 and self.activation is not None: + return False + if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4: + return False + if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1: + return False + if not self.valid_broadcast: + return False + return True + + +class AddParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Add composite function + and extract the parameter information. + """ + + composite_name = "ethosu.add" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "ADD", True) + + def is_valid(self): + """ + This function checks whether Add has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.add with optional fused RELU activation. + """ + pattern = is_op("qnn.add")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class SubParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Sub composite function + and extract the parameter information. + """ + + composite_name = "ethosu.sub" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "SUB", True) + + def is_valid(self): + """ + This function checks whether Sub has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.subtract with optional fused RELU activation. + """ + pattern = is_op("qnn.subtract")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MulParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Mul composite function + and extract the parameter information. + """ + + composite_name = "ethosu.mul" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MUL", True) + + def is_valid(self): + """ + This function checks whether Mul has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.mul with optional fused RELU activation. + """ + pattern = is_op("qnn.mul")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MinParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Min composite function + and extract the parameter information. + """ + + composite_name = "ethosu.min" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MIN", False) + + def is_valid(self): + """ + This function checks whether Min has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if self.ifm.dtype != self.ifm2.dtype: + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] + ): + return False + return True + + +def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for minimum with optional fused RELU activation. + """ + pattern = is_op("minimum")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MaxParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Max composite function + and extract the parameter information. + """ + + composite_name = "ethosu.max" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MAX", False) + + def is_valid(self): + """ + This function checks whether Max has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if self.ifm.dtype != self.ifm2.dtype: + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] + ): + return False + return True + + +def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for maximum with optional fused RELU activation. + """ + pattern = is_op("maximum")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class ShlParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Shl composite function + and extract the parameter information. + """ + + composite_name = "ethosu.shl" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "SHL", False) + + def is_valid(self): + """ + This function checks whether Shl has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes([self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.int32]): + return False + return True + + +def shl_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for left_shift with optional fused RELU activation. + """ + pattern = is_op("left_shift")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + @register_pattern_table("ethosu") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -344,6 +781,46 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_depthwise_conv2d_pattern(), lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(), ), + ( + MaxPool2DParams.composite_name, + qnn_maxpool2d_pattern(), + lambda pat: MaxPool2DParams(pat).is_valid(), + ), + ( + AvgPool2DParams.composite_name, + qnn_avgpool2d_pattern(), + lambda pat: AvgPool2DParams(pat).is_valid(), + ), + ( + AddParams.composite_name, + qnn_add_pattern(), + lambda pat: AddParams(pat).is_valid(), + ), + ( + SubParams.composite_name, + qnn_subtract_pattern(), + lambda pat: SubParams(pat).is_valid(), + ), + ( + MulParams.composite_name, + qnn_mul_pattern(), + lambda pat: MulParams(pat).is_valid(), + ), + ( + MinParams.composite_name, + minimum_pattern(), + lambda pat: MinParams(pat).is_valid(), + ), + ( + MaxParams.composite_name, + maximum_pattern(), + lambda pat: MaxParams(pat).is_valid(), + ), + ( + ShlParams.composite_name, + shl_pattern(), + lambda pat: ShlParams(pat).is_valid(), + ), ] diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index b9d6806306f4..50f473aea1f2 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -24,7 +24,7 @@ import tvm from tvm import relay from tvm.relay.adt import Pattern -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.expr import Expr, GlobalVar, Var from tvm.relay.function import Function from tvm.relay.expr_functor import ExprFunctor @@ -61,7 +61,7 @@ def __init__(self, mod, target) -> None: super().__init__() self.mod = mod self.tgt = target - self.engine = compile_engine.get() + self.tec = te_compiler.get() self.fun_no = 0 self.var_no = 0 self.var_map = {} @@ -153,7 +153,10 @@ def parse_name(self, name: str): def parse_numpy_array(self, arr): """Given a Numpy array, produces an appropriate Python array or numerical literal representing its contents.""" - parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i) + + def parse_single(i): + return NameConstant(i) if isinstance(i, bool) else Num(i) + if arr.ndim == 0: return parse_single(arr.item()) if arr.ndim == 1: @@ -240,11 +243,11 @@ def create_op_call(self, op: Function, relay_args, py_args): the generated Python code.""" # compile the function and register globally - cc_key = compile_engine.CCacheKey(op, self.tgt) + cc_key = te_compiler.CCacheKey(op, self.tgt) func_hash = tvm.ir.structural_hash(op) op_name = "_lowered_op_{}".format(func_hash) if not tvm.get_global_func(op_name, allow_missing=True): - jitted = self.engine.jit(cc_key, self.tgt) + jitted = self.tec.jit(cc_key, self.tgt) tvm.register_func(op_name, jitted) def convert_input(py_input, arg_type): diff --git a/python/tvm/rpc/server_ios_launcher.py b/python/tvm/rpc/server_ios_launcher.py new file mode 100644 index 000000000000..2e31586f6456 --- /dev/null +++ b/python/tvm/rpc/server_ios_launcher.py @@ -0,0 +1,498 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Python wrapper for running a RPC Server through iOS RPC +on the iOS simulator using the simctl command line tool. +""" +# pylint: disable=invalid-name +import os +import json +import time +import threading +import subprocess +from enum import Enum +from typing import Dict, List, AnyStr + + +class OSName(Enum): + """The names of the operating systems available on the simulator.""" + + iOS = "iOS" + tvOS = "tvOS" + watchOS = "watchOS" + + +class IOSDevice(Enum): + """The names of available iOS devices.""" + + iPhone = "iPhone" + iPod = "iPod" + iPad = "iPad" + + +class RPCServerMode(Enum): + """Server modes available in the iOS RPC application.""" + + standalone = "standalone" + proxy = "proxy" + tracker = "tracker" + + +def get_list_of_available_simulators() -> Dict[AnyStr, List]: + """ + List of simulators available on the system. Simulators are presented as a dictionary. + The dictionary key is the name of the operating system of the simulator. + The dictionary value is a list of all simulators with a given operating system. + """ + + with subprocess.Popen( + "xcrun simctl list devices available --json", + shell=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) as proc: + out, _ = proc.communicate() + available_simulators = json.loads(out)["devices"] + available_simulators = { + key: value for key, value in available_simulators.items() if value != [] + } + return available_simulators + + +def grep_by_system(available_devices: Dict[AnyStr, List], system_name: OSName) -> List[Dict]: + """Search for simulators that use the target operating system.""" + + def find_index_of_substr(search_field: List[AnyStr], target: AnyStr) -> int: + for i, item in enumerate(search_field): + if target in item: + return i + raise ValueError("Search field doesn't content target") + + keys = list(available_devices.keys()) + + return available_devices[keys[find_index_of_substr(keys, system_name.value)]] + + +def grep_by_device(available_devices: List[Dict], device_name: IOSDevice) -> List[Dict]: + """Search for simulators that emulate a given device.""" + + return [item for item in available_devices if device_name.value in item["name"]] + + +def get_device_uid(target_device: Dict) -> AnyStr: + """Get a unique device ID.""" + + return target_device["udid"] + + +def check_call_with_runtime_error(cmd: AnyStr, error_message: AnyStr) -> None: + """Calling the function `subprocess.check_call` and catching its possible thrown exception.""" + + try: + subprocess.check_call(cmd.split(" ")) + except subprocess.CalledProcessError as called_process_error: + raise called_process_error from RuntimeError(error_message) + + +def boot_device(udid: AnyStr) -> None: + """Boot the device by its unique ID.""" + + cmd = f"xcrun simctl boot {udid}" + error_message = f"Failed to boot device with unique id: {udid}" + check_call_with_runtime_error(cmd, error_message) + if not is_booted(udid): + raise RuntimeError(error_message) + + +def shutdown_device(udid: AnyStr) -> None: + """Shutdown the device by its unique ID.""" + + cmd = f"xcrun simctl shutdown {udid}" + error_message = f"Failed to shut down device with unique id: {udid}" + check_call_with_runtime_error(cmd, error_message) + if not is_turned_off(udid): + raise RuntimeError(error_message) + + +def deploy_bundle_to_simulator(udid: AnyStr, bundle_path: AnyStr) -> None: + """Deploy iOS RPC bundle to simulator with its unique ID .""" + + check_call_with_runtime_error( + cmd=f"xcrun simctl install {udid} {bundle_path}", + error_message=f"Failed to deploy bundle <{bundle_path}> to device with unique id: {udid}", + ) + + +def delete_bundle_from_simulator(udid: AnyStr, bundle_id: AnyStr) -> None: + """Delete iOS RPC bundle from simulator with its unique ID .""" + + check_call_with_runtime_error( + cmd=f"xcrun simctl uninstall {udid} {bundle_id}", + error_message=f"Failed to uninstall bundle <{bundle_id}> " + f"from device with unique id: {udid}", + ) + + +def launch_ios_rpc( + udid: AnyStr, bundle_id: AnyStr, host_url: AnyStr, host_port: int, key: AnyStr, mode: AnyStr +): # pylint: disable=too-many-arguments, consider-using-with + """ + Launch iOS RPC application on simulator with No UI interconnection. + + udid : str + Unique device ID. + + bundle_id : str + iOS RPC bundle ID. + + host_url : str + The tracker/proxy address. + + host_port : int + The tracker/proxy port. + + key : str + The key used to identify the device type in tracker. + + mode : str + Server mode. See RPCServerMode. + """ + + cmd = ( + f"xcrun simctl launch --console {udid} {bundle_id}" + f" --immediate_connect" + f" --host_url={host_url}" + f" --host_port={host_port}" + f" --key={key}" + f" --server_mode={mode}" + f" --verbose" + ) + proc = subprocess.Popen( + cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=1, + universal_newlines=True, + ) + return proc + + +def terminate_ios_rpc(udid: AnyStr, bundle_id: AnyStr) -> None: + """Terminate iOS RPC application.""" + + check_call_with_runtime_error( + cmd=f"xcrun simctl terminate {udid} {bundle_id}", + error_message=f"Failed to terminate bundle <{bundle_id}> " + f"from device with unique id: {udid}", + ) + + +def is_booted(udid: AnyStr) -> bool: + """Check that the device has booted.""" + + device = find_device(udid) + return device["state"] == "Booted" + + +def is_turned_off(udid: AnyStr) -> bool: + """Check that the device has turned off.""" + + device = find_device(udid) + return device["state"] == "Shutdown" + + +def check_booted_device(devices: List[Dict]) -> Dict: + """Check if there is already a booted device. If so, return this device.""" + + for device in devices: + if device["state"] == "Booted": + return device + return {} + + +def find_device(udid: AnyStr) -> Dict: + """Find device by its unique ID.""" + + return_value = {} + available_devices = get_list_of_available_simulators() + for devices in available_devices.values(): + for device in devices: + if device["udid"] == udid: + return_value = device + return return_value + + +class ServerIOSLauncher: + """ + Python wrapper for launch iOS RPC to simulator. + + mode : str + Server mode. See RPCServerMode. + + host : str + The tracker/proxy address. + + port : int + The tracker/proxy port. + + key : str + The key used to identify the device type in tracker. + """ + + booted_devices = [] + bundle_id = os.environ.get("BUNDLE_ID") + bundle_path = os.environ.get("BUNDLE_PATH") + + class ConsoleMarkers(Enum): + """ + Marker-messages that iOS RPC Server should print to the console output + when its states change (see apps/ios_rpc/tvmrpc/RPCServer.mm). + + STOPPED : str + iOS RPC Server process was stopped + + CALLSTACK : str + Call stack if RPC Server was stopped with an error. + + CONNECTED : str + RPC Server reports that it successfully connected. + + SERVER_IP : str + IP on which RPC Server started (for standalone mode). + + SERVER_PORT : str + HOST on which RPC Server started (for standalone mode). + """ + + STOPPED = "PROCESS_STOPPED" + CALLSTACK = "First throw call stack" + CONNECTED = "[IOS-RPC] STATE: 2" + SERVER_IP = "[IOS-RPC] IP: " + SERVER_PORT = "[IOS-RPC] PORT: " + + def __init__(self, mode, host, port, key): + if not ServerIOSLauncher.is_compatible_environment(): + raise RuntimeError( + "Can't create ServerIOSLauncher instance." + " No environment variables set for iOS RPC Server." + ) + + self.host = host + self.port = port + + self.external_booted_device = None + if not ServerIOSLauncher.booted_devices: + self._boot_or_find_booted_device() + + self.udid = get_device_uid( + self.external_booted_device + if self.external_booted_device is not None + else ServerIOSLauncher.booted_devices[-1] + ) + + self.bundle_was_deployed = False + deploy_bundle_to_simulator(self.udid, self.bundle_path) + self.bundle_was_deployed = True + + self.server_was_started = False + self.launch_process = launch_ios_rpc(self.udid, self.bundle_id, host, port, key, mode) + self._wait_launch_complete( + waiting_time=60, + hz=10, + should_print_host_and_port=mode == RPCServerMode.standalone.value, + ) + self.server_was_started = True + + def terminate(self): + """Terminate iOS RPC server.""" + + if self.bundle_was_deployed and self.server_was_started: + try: + terminate_ios_rpc(self.udid, self.bundle_id) + self.launch_process.terminate() + self.server_was_started = False + except RuntimeError as e: + print(e) + if self.bundle_was_deployed: + try: + delete_bundle_from_simulator(self.udid, self.bundle_id) + self.bundle_was_deployed = False + except RuntimeError as e: + print(e) + + def __del__(self): + self.terminate() + + @staticmethod + def is_compatible_environment(): + """Check that the current environment has the required variables.""" + + return bool(os.environ.get("BUNDLE_ID")) and bool(os.environ.get("BUNDLE_PATH")) + + @staticmethod + def shutdown_booted_devices(): + """Shutdown simulators that have been booted using this class.""" + + for device_meta in ServerIOSLauncher.booted_devices: + try: + shutdown_device(get_device_uid(device_meta)) + except RuntimeError as e: + print(e) + ServerIOSLauncher.booted_devices = [] + + def _boot_or_find_booted_device(self): + """ + Boot the required simulator if there is no suitable booted simulator + among the available simulators. If there is a suitable booted simulator, + then take it as a simulator to which the iOS RPC application will be deployed. + """ + + target_system = OSName.iOS + target_device_type = IOSDevice.iPhone + available_devices = get_list_of_available_simulators() + if not available_devices: + raise ValueError("No devices available in this environment") + target_devices = grep_by_system(available_devices, target_system) + if not target_devices: + raise ValueError(f"No available simulators for target system: {target_system.value}") + target_devices = grep_by_device(target_devices, target_device_type) + if not target_devices: + raise ValueError( + f"No available simulators for target device type: {target_device_type.value}" + ) + + maybe_booted = check_booted_device(target_devices) + if maybe_booted: + self.external_booted_device = maybe_booted + else: + take_latest_model = True + target_device = target_devices[-1 if take_latest_model else 0] + boot_device(get_device_uid(target_device)) + ServerIOSLauncher.booted_devices.append(target_device) + + def _wait_launch_complete(self, waiting_time, hz, should_print_host_and_port=False): + # pylint: disable=too-many-locals + """ + Wait for the iOS RPC server to start. + + waiting_time : int + The maximum waiting time during which it is necessary + to receive a message from RPC Server. + + hz : int + The frequency of checking (in hertz) messages from RPC Server. + Checks for messages from the server will occur every 1 / hz second. + + should_print_host_and_port : bool + A flag that indicates that RPC Server should print the host and port + on which it was started. + Used for standalone mode. + """ + + class Switch: + """A simple helper class for boolean switching.""" + + def __init__(self): + self._on = False + + def toggle(self): + """Toggle flag.""" + self._on = not self._on + + @property + def on(self): + """Flag of this switch.""" + return self._on + + def watchdog(): + for _ in range(waiting_time * hz): + time.sleep(1.0 / hz) + if switch_have_data.on: + break + if not switch_have_data.on: + self.launch_process.terminate() + switch_process_was_terminated.toggle() + + switch_have_data = Switch() + switch_process_was_terminated = Switch() + watchdog_thread = threading.Thread(target=watchdog) + + host, port = None, None + watchdog_thread.start() + for line in self.launch_process.stdout: + if not switch_have_data.on: + switch_have_data.toggle() + + found = str(line).find(ServerIOSLauncher.ConsoleMarkers.STOPPED.value) + if found != -1: + raise RuntimeError("[ERROR] Crash during RCP Server launch.. ") + + found = str(line).find(ServerIOSLauncher.ConsoleMarkers.CALLSTACK.value) + if found != -1: + raise RuntimeError("[ERROR] Crash during RCP Server launch.. ") + + found = str(line).find(ServerIOSLauncher.ConsoleMarkers.SERVER_IP.value) + if found != -1: + ip = str(line)[ + found + len(ServerIOSLauncher.ConsoleMarkers.SERVER_IP.value) : + ].rstrip("\n") + host = ip + + found = str(line).find(ServerIOSLauncher.ConsoleMarkers.SERVER_PORT.value) + if found != -1: + port = str(line)[ + found + len(ServerIOSLauncher.ConsoleMarkers.SERVER_PORT.value) : + ].rstrip("\n") + port = int(port) + + if str(line).find(ServerIOSLauncher.ConsoleMarkers.CONNECTED.value) != -1: + # rpc server reports that it successfully connected + break + watchdog_thread.join() + + if switch_process_was_terminated.on: + raise TimeoutError("Can't get a response from the iOS Server.") + if should_print_host_and_port: + if host is None or port is None: + raise RuntimeError("No messages with actual host and port.") + self.port = port + + +class ServerIOSContextManager: + """ + Context manager for ServerIOSLauncher. + To work with ServerIOSLauncher, it is preferable to use this class + so that the terminate method is called in any case. + """ + + def __init__(self, mode, host, port, key): + self.__mode = mode + self.__host = host + self.__port = port + self.__key = key + self.__ios_rpc_server_launcher = None + + def __enter__(self): + self.__ios_rpc_server_launcher = ServerIOSLauncher( + self.__mode, self.__host, self.__port, self.__key + ) + return self.__ios_rpc_server_launcher + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.__ios_rpc_server_launcher is not None: + self.__ios_rpc_server_launcher.terminate() + self.__ios_rpc_server_launcher = None diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index b91fe727698b..7d40a81e498a 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -47,6 +47,35 @@ def csv(self): """ return _ffi_api.AsCSV(self) + def table(self, sort=True, aggregate=True, col_sums=True): + """Generate a human-readable table + + Parameters + ---------- + sort : bool + + If aggregate is true, whether to sort call frames by + descending duration. If aggregate is False, whether to + sort frames by order of appearancei n the program. + + aggregate : bool + + Whether to join multiple calls to the same op into a + single line. + + col_sums : bool + + Whether to include the sum of each column. + + Returns + ------- + table : str + + A human-readable table + + """ + return _ffi_api.AsTable(self, sort, aggregate, col_sums) + def json(self): """Convert this profiling report into JSON format. diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 8ebb0f6301d2..c1cbc966acdc 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -71,6 +71,7 @@ def __init__(self, mod): self._save = self.mod["save"] self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] + self._get_constants = self.mod["get_constants"] self._get_stats = self.mod["get_stats"] self._get_function_arity = self.mod["get_function_arity"] self._get_function_param_name = self.mod["get_function_param_name"] @@ -244,6 +245,12 @@ def bytecode(self): """ return self._get_bytecode() + @property + def constants(self): + """Returns a human-readable description of all the constants in the executable. + Useful for debugging and diffing generated executables in unit tests.""" + return self._get_constants() + @property def globals(self): """Get the globals used by the Relay VM executable. diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 8610d91e9f07..080aa0476bec 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -490,6 +490,31 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: self.context.exit_scope() return func + def transform_Lambda(self, node): + """Lambda visitor + + Return an array of input parameters and the transformed lambda body. + """ + + self.context.enter_scope(nodes=[node.body]) + + # add parameters of the lambda + arg_vars = [] + for arg in node.params: + arg_var = tvm.te.var(arg.name) + arg_vars.append(arg_var) + self.context.update_symbol(arg.name, arg_var, node) + + # the body of a lambda must be an expr + if not isinstance(node.body, ast.Expr): + self.report_error("The body of a lambda must be an expression", node.span) + + # transform the body of the lambda + body = self.transform(node.body) + + self.context.exit_scope() + return arg_vars, body + def transform_Assign(self, node): """Assign visitor AST abstract grammar: diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi new file mode 100644 index 000000000000..fba026d414f6 --- /dev/null +++ b/python/tvm/script/tir/__init__.pyi @@ -0,0 +1,359 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Iterable, + Optional, + Tuple, + Union, + Sequence, + List, + Mapping, + overload, +) +from numbers import Number +import builtins + +from tvm.tir.function import PrimFunc +from tvm.tir import Range +from tvm.runtime import Object +from .node import BufferSlice + +""" +redefine types +""" + +class PrimExpr: + def __init__(self: PrimExpr) -> None: ... + @overload + def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... + @overload + def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... + +class Var(PrimExpr): ... +class IterVar(Var): ... + +class Buffer: + @overload + def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ... + @overload + def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ... + @overload + def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ... + @overload + def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ... + @property + def data(self: Buffer) -> Ptr: ... + +""" +Variables and constants +""" + +def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ... +def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ... +def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ... + +""" +Intrinsic +""" + +def min_value(dtype: str) -> PrimExpr: ... +def max_value(dtype: str) -> PrimExpr: ... +def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... +def abs(x: PrimExpr) -> PrimExpr: ... +def load( + dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None +) -> PrimExpr: ... +def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... +def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... +def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... +def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ... +def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... +def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... +def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... +def evaluate(value: PrimExpr) -> None: ... +def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... +def store( + var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True +) -> None: ... +def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ... + +""" +Unary operator +""" + +def exp2(x: PrimExpr) -> PrimExpr: ... +def exp10(x: PrimExpr) -> PrimExpr: ... +def erf(x: PrimExpr) -> PrimExpr: ... +def tanh(x: PrimExpr) -> PrimExpr: ... +def sigmoid(x: PrimExpr) -> PrimExpr: ... +def log(x: PrimExpr) -> PrimExpr: ... +def log2(x: PrimExpr) -> PrimExpr: ... +def log10(x: PrimExpr) -> PrimExpr: ... +def log1p(x: PrimExpr) -> PrimExpr: ... +def tan(x: PrimExpr) -> PrimExpr: ... +def cos(x: PrimExpr) -> PrimExpr: ... +def cosh(x: PrimExpr) -> PrimExpr: ... +def acos(x: PrimExpr) -> PrimExpr: ... +def acosh(x: PrimExpr) -> PrimExpr: ... +def sin(x: PrimExpr) -> PrimExpr: ... +def sinh(x: PrimExpr) -> PrimExpr: ... +def asin(x: PrimExpr) -> PrimExpr: ... +def asinh(x: PrimExpr) -> PrimExpr: ... +def atan(x: PrimExpr) -> PrimExpr: ... +def atanh(x: PrimExpr) -> PrimExpr: ... +def atan2(x: PrimExpr) -> PrimExpr: ... +def sqrt(x: PrimExpr) -> PrimExpr: ... +def rsqrt(x: PrimExpr) -> PrimExpr: ... + +""" +special_stmt - Buffers +""" + +def match_buffer( + param: Union[Var, BufferSlice], + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... +def buffer_decl( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... +def alloc_buffer( + shape: Sequence[Union[PrimExpr, int]], + dtype: str = "float32", + data: Var = None, + strides: Optional[Sequence[int]] = None, + elem_offset: Optional[int] = None, + scope: str = "global", + align: int = -1, + offset_factor: int = 0, + buffer_type: str = "default", +) -> Buffer: ... + +""" +special_stmt - Reads/Writes +""" + +def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def block_attr(attrs: Mapping[str, Object]) -> None: ... + +""" +special_stmt - Axis +""" + +class axis: + @overload + @staticmethod + def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def spatial( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def reduce( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def scan( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @overload + @staticmethod + def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... + @overload + @staticmethod + def opaque( + dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr + ) -> IterVar: ... + @staticmethod + def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... + +def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ... + +""" +special_stmt - Annotations +""" + +def buffer_var(dtype: str, storage_scope: str) -> Var: ... +def func_attr(attrs: Mapping[str, Object]) -> None: ... +def prim_func(input_func: Callable) -> PrimFunc: ... + +""" +special_stmt - Threads and Bindings +""" + +def env_thread(env_name: str) -> IterVar: ... +def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... + +""" +Scope handler +""" + +class block(ContextManager): + def __init__(self, name_hint: str = "") -> None: ... + def __enter__(self) -> Sequence[IterVar]: ... + +class init(ContextManager): + def __init__(self) -> None: ... + +class let(ContextManager): + def __init__(self, var: Var, value: PrimExpr) -> None: ... + +def where(cond: PrimExpr) -> None: ... +def allocate( + extents: List[PrimExpr], + dtype: str, + scope: str, + condition: Union[PrimExpr, builtins.bool] = True, + annotations: Optional[Mapping[str, Object]] = None, +) -> Var: ... +def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... +def realize( + buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True +) -> None: ... +def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... +def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ... + +""" +Scope handler - Loops +""" + +def serial( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def parallel( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def vectorized( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def unroll( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def thread_binding( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int], + thread: str, + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def for_range( + begin: Union[PrimExpr, int], + end: Union[PrimExpr, int] = None, + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ... + +""" +ty - redefine types +""" + +class boolean: ... + +class handle: + @overload + def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ... + @overload + def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ... + @overload + def __setitem__( + self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer + ) -> None: ... + @overload + def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ... + @property + def data(self: handle) -> Ptr: ... + +class Ptr: ... diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 4d7fe80b28b1..d31e93c72b15 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -16,6 +16,7 @@ # under the License. """TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level +import builtins from typing import List, Any import tvm.tir @@ -120,6 +121,11 @@ def floormod(x, y, span): return tvm.tir.floormod(x, y, span) +@register +def truncmod(x, y, span): + return tvm.tir.truncmod(x, y, span) + + @register def abs(x, span): return tvm.tir.abs(x, span) @@ -211,3 +217,20 @@ def store(var, index, value, predicate=True, span=None): return tvm.tir.Store(var, value, index, predicate, span) super().__init__(store, stmt=True) + + +@register +def comm_reducer(lambda_io, identities, span): + """Create a CommReducer from lambda inputs/outputs and the identities""" + lambda_input = lambda_io[0] + lambda_output = lambda_io[1] + + num_args = len(lambda_input) + num_arg_per_group = num_args // 2 + x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)] + y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)] + + if not isinstance(lambda_output, tuple): + lambda_output = (lambda_output,) + + return tvm.tir.CommReducer(x, y, lambda_output, identities, span) diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 6a4f7bc00cb6..9140310d4733 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -30,8 +30,13 @@ def evaluate(self): """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") + # This function is added here to avoid a pylint error + # for T.int/float below not being callable + def __call__(self): + raise NotImplementedError() -class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods + +class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method """TVM script typing class for uniform Type objects""" def __init__(self, vtype): @@ -41,7 +46,7 @@ def evaluate(self): return tvm.ir.PrimType(self.type) -class GenericPtrType(TypeGeneric): +class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for PtrType [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType @@ -51,7 +56,7 @@ def __getitem__(self, vtype): return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) -class GenericTupleType(TypeGeneric): +class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method """TVM script typing class generator for TupleType [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1e906cb381d8..5dc95c3ae675 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,6 +71,8 @@ riscv_cpu, hexagon, ) +from .se_scope import make_se_scope +from .compilation_config import make_compilation_config from .tag import list_tags from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func diff --git a/python/tvm/target/compilation_config.py b/python/tvm/target/compilation_config.py new file mode 100644 index 000000000000..2796ec4b5135 --- /dev/null +++ b/python/tvm/target/compilation_config.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating CompilationConfigs.""" +from . import _ffi_api + + +def make_compilation_config(ctxt, targets, host_target=None): + """Returns a CompilationConfig appropriate for targets and an optional host_target. + Currently intended just for unit tests and will be replaced by a Python CompilationConfig + class in the future. Note that targets must be a dictionary from IntImm objects to Targets + and we do not support any of the lighter-weight conventions used by the various build(...) + APIs.""" + return _ffi_api.MakeCompilationConfig(ctxt, targets, host_target) diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/se_scope.py new file mode 100644 index 000000000000..83df5ae3448a --- /dev/null +++ b/python/tvm/target/se_scope.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Python bindings for creating SEScopes.""" +from . import _ffi_api + + +def make_se_scope(device, target=None, memory_scope=""): + return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 250c165caf9a..308257085e51 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -33,7 +33,7 @@ from .tag import tag_scope from .operation import placeholder, compute, scan, extern, var, size_var from .operation import thread_axis, reduce_axis -from .operation import create_prim_func +from .operation import create_prim_func, create_prim_func_from_outputs from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index cb0305d49e4a..5cb58a85ed10 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -17,14 +17,14 @@ """ Operation class for computation declaration.""" # pylint: disable=invalid-name from numbers import Integral as _Integral -from typing import List +from typing import List, Union import tvm._ffi -import tvm.tir -import tvm.tir._ffi_api from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert +import tvm.tir +import tvm.tir._ffi_api from . import _ffi_api from . import tag as _tag @@ -482,3 +482,23 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: if not isinstance(ops, (list, tuple, Array)): ops = [ops] return _ffi_api.CreatePrimFunc(ops) + + +def create_prim_func_from_outputs( + outputs: Union[_tensor.Tensor, List[_tensor.Tensor]], +) -> tvm.tir.PrimFunc: + """Create a TensorIR PrimFunc from output tensor(s) in TE + + Parameters + ---------- + outputs : Union[Tensor, List[Tensor]] + The source expression. + + Returns + ------- + func : tir.PrimFunc + The created function. + """ + if not isinstance(outputs, (list, tuple, Array)): + outputs = [outputs] + return _ffi_api.CreatePrimFuncFromOutputs(outputs) diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index 2cb228c357e5..c0decb7747bd 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -253,7 +253,13 @@ def _sort_tests(items): Should be called from pytest_collection_modifyitems. """ - items.sort(key=lambda item: item.location) + + def sort_key(item): + filename, lineno, test_name = item.location + test_name = test_name.split("[")[0] + return filename, lineno, test_name + + items.sort(key=sort_key) def _target_to_requirement(target): diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 44006239acfd..428403a98f16 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -25,7 +25,7 @@ from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle -from .expr import Call, CallEffectKind, Let, IterVar, Any +from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 2bfa0aacb184..27cf5351a077 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -442,7 +442,7 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): @tvm._ffi.register_object("tir.CommReducer") class CommReducer(Object): - """Communicative reduce operator + """Commutative reduce operator Parameters ---------- diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index b002ace0e400..ecbcd837cb72 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -143,7 +143,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: """ return _ffi_api.Specialize(self, param_map) # type: ignore - def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str: + def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript Parameters diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 786982cf704c..884eeb7c612c 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -325,6 +325,39 @@ def sample_categorical( decision, ) + def sample_perfect_tile( + self, + loop: LoopRV, + n: int, + max_innermost_factor: int = 16, + decision: Optional[List[int]] = None, + ) -> List[ExprRV]: + """Sample the factors to perfect tile a specific loop + + Parameters + ---------- + loop : LoopRV + The loop to be tiled + n : int + The number of tiles to be sampled + max_innermost_factor : int + The maximum tile size allowed to be sampled in the innermost loop + decision: Optional[List[int]] + The sampling decision, if any + + Returns + ------- + result : List[ExprRV] + A list of length `n`, the random perfect tile sizes sampled + """ + return _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member + self, + loop, + n, + max_innermost_factor, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, @@ -367,6 +400,51 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: """ return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]: + """Get the leaf blocks of a specific block/loop + + Parameters + ---------- + block_or_loop : Union[BlockRV, LoopRV] + The query block/loop + + Returns + ------- + blocks : List[LoopRV] + A list of leaf blocks inside a specific block/loop + """ + return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type: ignore # pylint: disable=no-member + + def get_producers(self, block: BlockRV) -> List[BlockRV]: + """Get the producers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + producers : List[BlockRV] + A list of producers of the given block + """ + return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member + + def get_consumers(self, block: BlockRV) -> List[BlockRV]: + """Get the consumers of a specific block + + Parameters + ---------- + block : BlockRV + The block in the query + + Returns + ------- + consumers : List[BlockRV] + A list of consumers of the given block + """ + return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member + ########## Schedule: Transform loops ########## def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. It requires: diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 66ede31f4103..04cbffcd4d87 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -17,18 +17,17 @@ """Testing utilities for the TensorIR schedule API""" from typing import Union -from tvm import tir from tvm.ir import IRModule, structural_equal from tvm.tir import PrimFunc -from tvm.tir.schedule import Trace +from tvm.tir.schedule import Trace, Schedule def verify_trace_roundtrip( - sch: tir.Schedule, + sch: Schedule, mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "all", -) -> tir.Schedule: +) -> Schedule: """Serialize a traced schedule to JSON, then replay the JSON trace by applying to a fresh new schedule, verifying the reproducibility of scheduling. @@ -51,7 +50,7 @@ def verify_trace_roundtrip( assert trace is not None json_obj = trace.as_json() # Step 2. Apply the JSON trace to a new schedule, then check if it reproduces the scheduling - new_sch = tir.Schedule(mod=mod, debug_mask=debug_mask) + new_sch = Schedule(mod=mod, debug_mask=debug_mask) Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) assert structural_equal(new_sch.mod, sch.mod) # Step 3. Check the consistency of the text format between the old and new traces diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py index c7c572c81110..cbe8644c885f 100644 --- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py +++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py @@ -90,7 +90,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/arm_cpu/injective.py b/python/tvm/topi/arm_cpu/injective.py index 55f47c5dee4d..330144b33fb6 100644 --- a/python/tvm/topi/arm_cpu/injective.py +++ b/python/tvm/topi/arm_cpu/injective.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-variable """Schedule for pooling operators""" import tvm +import numpy as np from tvm import te from ..utils import is_empty_shape @@ -67,7 +68,7 @@ def schedule_injective(outs): if list(s[x].op.axis): # do not vectorize for broadcast - (io, ii) = s[x].split(list(s[x].op.axis)[-1], 4) + (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(x.dtype).itemsize) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/python/tvm/topi/bifrost/conv2d.py b/python/tvm/topi/bifrost/conv2d.py index 3b6cca6aaea4..633f36c0e7ff 100644 --- a/python/tvm/topi/bifrost/conv2d.py +++ b/python/tvm/topi/bifrost/conv2d.py @@ -477,7 +477,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index 4863a06b728d..3d05058ff52c 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -46,7 +46,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/cuda/conv3d_alter_op.py b/python/tvm/topi/cuda/conv3d_alter_op.py index faf73e77255a..c7ec7cb21fcf 100644 --- a/python/tvm/topi/cuda/conv3d_alter_op.py +++ b/python/tvm/topi/cuda/conv3d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv3d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) dispatch_ctx = autotvm.task.DispatchContext.current - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/intel_graphics/conv2d_alter_op.py b/python/tvm/topi/intel_graphics/conv2d_alter_op.py index 0b59a849c2c9..199d984af1e4 100644 --- a/python/tvm/topi/intel_graphics/conv2d_alter_op.py +++ b/python/tvm/topi/intel_graphics/conv2d_alter_op.py @@ -35,7 +35,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - _, outs = relay.backend.compile_engine.select_implementation( + _, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index f3ef55b9a30c..051914113a5b 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -531,7 +531,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 8e47dff37ce6..3f2df655a615 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -57,7 +57,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): cfg = dispatch_ctx.query(target, None) workload = cfg.workload else: - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 8db84497f82d..1d64261a50d7 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -35,7 +35,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type): M, K = get_const_tuple(data_tensor.shape) N, _ = get_const_tuple(weight_tensor.shape) - impl, outs = relay.backend.compile_engine.select_implementation( + impl, outs = relay.backend.te_compiler.select_implementation( relay.op.get("nn.dense"), attrs, tinfos, out_type, target ) workload = autotvm.task.get_workload(outs) diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 7793f9f6962e..80c7efbaf894 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -224,10 +224,9 @@ fn main() -> Result<()> { }?; // If the TVM_HOME environment variable changed, the LLVM_CONFIG_PATH environment variable - // changed, the build directory or headers have changed we need to rebuild the Rust bindings. + // changed or the source headers have changed we need to rebuild the Rust bindings. println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rerun-if-env-changed=LLVM_CONFIG_PATH"); - println!("cargo:rerun-if-changed={}", build_path.display()); println!("cargo:rerun-if-changed={}/include", source_path.display()); let library_name = if cfg!(feature = "runtime-only") { diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index f43967f28d60..b65b784bf400 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -163,7 +163,7 @@ impl Call { span: Span, ) -> Call { let node = CallNode { - base: ExprNode::base::(span), + base: ExprNode::base::(span), op: op, args: args, attrs: attrs, diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index a402212cf4ea..fe3a37f88fa4 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -71,6 +71,8 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + if (a->IsEmpty()) return b; + if (b->IsEmpty()) return a; PrimExpr max_value = max(a->max_value, b->max_value); PrimExpr min_value = min(a->min_value, b->min_value); return IntervalSet(min_value, max_value); diff --git a/src/contrib/torch/pt_call_tvm/tvm_class.cc b/src/contrib/torch/pt_call_tvm/tvm_class.cc new file mode 100644 index 000000000000..5e57dc152f11 --- /dev/null +++ b/src/contrib/torch/pt_call_tvm/tvm_class.cc @@ -0,0 +1,686 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace contrib { +namespace pytorch { + +/*! \brief Class holding necessary components to call TVM graph runtime */ +class TvmGraphModulePack { + public: + /*! + * \brief Constructor. + * + * \param path Encoded path of graph runtime assets. + * \param device_type int64_t, kDLCPU or kDLCUDA. + * \param device_id int64_t. + */ + explicit TvmGraphModulePack(std::string path, int64_t device_type, int64_t device_id) + : path_(std::move(path)) { + LOG(INFO) << "[TvmGraphModule] loading module at path: [" << path_ << "] on device [" + << (device_type == kDLCUDA ? "cuda:" : "cpu:") << device_id << "]..."; + std::string lib_path, graph_path, params_path; + DecodePaths(path_, &lib_path, &graph_path, ¶ms_path); + + // load graph + std::ifstream graph_in(graph_path); + std::string graph_data((std::istreambuf_iterator(graph_in)), + std::istreambuf_iterator()); + graph_in.close(); + + // load mod syslib + tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(lib_path); + + const auto runtime_create = *tvm::runtime::Registry::Get("tvm.graph_executor.create"); + + // read params data + std::ifstream params_in(params_path, std::ios::binary); + std::string params_data((std::istreambuf_iterator(params_in)), + std::istreambuf_iterator()); + params_in.close(); + TVMByteArray params_arr; + params_arr.data = params_data.c_str(); + params_arr.size = params_data.length(); + + // set devices + module_ = runtime_create(graph_data, lib, device_type, device_id); + const tvm::runtime::PackedFunc load_params = module_.GetFunction("load_params"); + load_params(params_arr); + + set_input = module_.GetFunction("set_input_zero_copy"); + run = module_.GetFunction("run"); + get_output = module_.GetFunction("get_output"); + set_output = module_.GetFunction("set_output_zero_copy"); + num_outputs_ = module_.GetFunction("get_num_outputs")(); + } + + static constexpr char kPathDelimiter = '|'; + + /*! + * \brief Decode lib_path, graph_path, params_path from encoded path. + * + * \param path The encoded path, concated with `kPathDelimiter`. + * \param lib_path The path of .so lib file. + * \param graph_path The path of graph.json. + * \param params_path The path of params data. + */ + static void DecodePaths(const std::string& path, std::string* lib_path, std::string* graph_path, + std::string* params_path) { + std::vector paths; + for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { + if (i == lim || path.at(i) == kPathDelimiter) { + paths.push_back(path.substr(pre, i - pre)); + pre = i + 1; + } + } + CHECK_EQ(paths.size(), 3u); + *lib_path = paths.at(0); + *graph_path = paths.at(1); + *params_path = paths.at(2); + } + + /*! + * \brief Encode lib_path, graph_path, params_path by concat then with `kPathDelimiter`. + * + * \param lib_path The path of .so lib file. + * \param graph_path The path of graph.json. + * \param params_path The path of params data. + * + * \return The encoded path, concated with `kPathDelimiter`. + */ + static std::string EncodePaths(const std::string& lib_path, const std::string& graph_path, + const std::string& params_path) { + return lib_path + kPathDelimiter + graph_path + kPathDelimiter + params_path; + } + + const std::string& path() const { return path_; } + + const int64_t num_outputs() const { return num_outputs_; } + + tvm::runtime::PackedFunc set_input; + tvm::runtime::PackedFunc run; + tvm::runtime::PackedFunc get_output; + tvm::runtime::PackedFunc set_output; + + private: + tvm::runtime::Module module_; + int64_t num_outputs_; + std::string path_; +}; + +/*! \brief Class holding necessary components to call TVM VM runtime */ +class TvmVMModulePack { + public: + /*! + * \brief Constructor. + * + * \param path Encoded path of vm runtime assets. + * \param device_type int64_t, kDLCPU or kDLCUDA. + * \param device_id int64_t. + */ + explicit TvmVMModulePack(std::string path, int64_t device_type, int64_t device_id) + : path_(std::move(path)) { + LOG(INFO) << "[TvmVMModule] loading module at path: [" << path_ << "] on device [" + << (device_type == kDLCUDA ? "cuda:" : "cpu:") << device_id << "]..."; + // build tvm graph runtime + std::string lib_path, code_path; + DecodePaths(path_, &lib_path, &code_path); + // load lib + auto loaded_lib = tvm::runtime::Module::LoadFromFile(lib_path, "so"); + // load code + std::ifstream code_in(code_path); + std::string loaded_code((std::istreambuf_iterator(code_in)), + std::istreambuf_iterator()); + code_in.close(); + exe_ = tvm::runtime::vm::Executable::Load(loaded_code, loaded_lib); + const auto runtime_create = *tvm::runtime::Registry::Get("runtime._VirtualMachine"); + vm_ = runtime_create(exe_); + auto init_func = vm_.GetFunction("init", false); + auto alloc_type = static_cast(tvm::runtime::vm::AllocatorType::kPooled); + if (device_type != kDLCPU) { + // CPU is required for executing shape functions + init_func(static_cast(kDLCPU), 0, alloc_type, device_type, device_id, alloc_type); + } else { + init_func(device_type, device_id, alloc_type); + } + set_input = vm_.GetFunction("set_input", false); + invoke = vm_.GetFunction("invoke", false); + } + + static constexpr char kPathDelimiter = '|'; + + /*! + * \brief Decode lib_path, code_path from encoded path. + * + * \param path The encoded path, concated with `kPathDelimiter`. + * \param lib_path The path of lib file. + * \param code_path The path of code file. + */ + static void DecodePaths(const std::string& path, std::string* lib_path, std::string* code_path) { + std::vector paths; + for (size_t i = 0, pre = 0, lim = path.size(); i <= lim; ++i) { + if (i == lim || path.at(i) == kPathDelimiter) { + paths.push_back(path.substr(pre, i - pre)); + pre = i + 1; + } + } + CHECK_EQ(paths.size(), 2u); + *lib_path = paths.at(0); + *code_path = paths.at(1); + } + + /*! + * \brief Encode lib_path, code_path by concat then with `kPathDelimiter`. + * + * \param lib_path The path of vm lib file. + * \param code_path The path of code. + * + * \return The encoded path, concated with `kPathDelimiter`. + */ + static std::string EncodePaths(const std::string& lib_path, const std::string& code_path) { + return lib_path + kPathDelimiter + code_path; + } + + const std::string& path() const { return path_; } + + tvm::runtime::PackedFunc set_input; + tvm::runtime::PackedFunc invoke; + + private: + tvm::runtime::Module exe_; + tvm::runtime::Module vm_; + std::string path_; +}; + +/*! \brief Pytorch custom class to call TVM */ +class BaseTvmClass : public torch::jit::CustomClassHolder { + public: + /*! + * \brief Constructor. + * + * \param num_inputs Number of inputs. + * \param num_outputs Number of outputs. + * \param device std::string, use the pytorch device str format, e.g. `cuda:0`, 'cpu' + */ + BaseTvmClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) + : num_inputs_(num_inputs), num_outputs_(num_outputs) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + } + + /*! \brief Virtual destructor. */ + virtual ~BaseTvmClass() {} + + /*! + * \brief Get repr string of pytorch input shapes. + * + * \param shapes Pytorch shapes of type List[List[int]]. + * + * \return std::string, the representation of inputs shapes. + */ + static std::string TvmShapeRepr(const c10::List>& shapes) { + std::stringstream ss; + for (const auto& shape : shapes) { + for (const auto& sz : static_cast>(shape)) { + ss << sz << "_"; + } + ss << "__"; + } + return ss.str(); + } + + /*! + * \brief Get input shapes. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[List[int]]. + */ + static c10::List> GetShapes(const c10::List& inputs) { + c10::List> shapes; + for (const auto& input : inputs) { + c10::List shape; + for (const auto sz : static_cast(input).sizes()) { + shape.push_back(sz); + } + shapes.push_back(shape); + } + return shapes; + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + virtual void to(const std::string& device) = 0; + + // getters + int64_t num_inputs() const { return num_inputs_; } + + int64_t num_outputs() const { return num_outputs_; } + + int64_t device_type() const { return device_type_; } + + int64_t device_id() const { return device_id_; } + + c10::DeviceType torch_device_type() const { + return device_type() == kDLCUDA ? torch::DeviceType::CUDA : torch::DeviceType::CPU; + } + + bool is_on_same_device(const torch::Tensor& tensor) const { + auto tensor_device_type = tensor.device().type(); + if (tensor_device_type == torch::DeviceType::CUDA) { + return tensor_device_type == torch_device_type() && device_id() == tensor.device().index(); + } + CHECK_EQ(tensor_device_type, torch::DeviceType::CPU); + return tensor_device_type == torch_device_type(); + } + + std::string device() const { return torch::Device(torch_device_type(), device_id()).str(); } + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. + */ + virtual c10::List forward(const c10::List& inputs) = 0; + + /*! + * \brief Serialize TVM Modules to Dict + */ + virtual c10::Dict SerializeTvmModules() const = 0; + + /*! + * \brief deserialize TVM Modules from Dict + */ + virtual void DeserializeTvmModules(const c10::Dict& shape_path_map) = 0; + + protected: + const int64_t num_inputs_; + const int64_t num_outputs_; + int64_t device_type_; + int64_t device_id_; +}; + +/*! \brief Pytorch custom class to call TVM graph runtime */ +class TvmGraphRuntimeClass : public BaseTvmClass { + public: + TvmGraphRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, + const std::string& device) + : BaseTvmClass(num_inputs, num_outputs, device) {} + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. + */ + c10::List forward(const c10::List& inputs) override { + CHECK_EQ(inputs.size(), num_inputs_); + auto shape_repr = TvmShapeRepr(GetShapes(inputs)); + std::vector args(num_inputs_ + num_outputs_); + auto iter = tvm_modules_.find(shape_repr); + CHECK(iter != tvm_modules_.end()); + const auto& tvm_pack = iter->second; + std::vector buf_infos; + buf_infos.reserve(num_inputs_ + num_outputs_); + + for (int i = 0; i < num_inputs_; ++i) { + at::Tensor inp = inputs[i]; + CHECK(is_on_same_device(inp)) + << "input #" << i + << " of forward is not on the same device with TvmGraphRuntime, expected " << device() + << " but got " << inp.device().str(); + inp = inp.contiguous(); + buf_infos.emplace_back(inp); + auto& input_buf = buf_infos[i]; + input_buf.CopyFromOrigin(); + input_buf.MakeDLTensor(&args[i]); + tvm_pack.set_input(i, &args[i]); + } + // prepare output buffers + c10::List outputs; + outputs.reserve(num_outputs_); + + for (int i = 0; i < num_outputs_; ++i) { + tvm::runtime::NDArray output_arr = tvm_pack.get_output(i); + std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); + + torch::ScalarType output_dtype = torch::ScalarType::Undefined; + CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); + + CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); + const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); + const auto options = + torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); + + outputs.emplace_back(torch::empty(output_shape, options)); + buf_infos.emplace_back(outputs[i]); + auto& output_buf = buf_infos[num_inputs_ + i]; + output_buf.MakeDLTensor(&args[num_inputs_ + i]); + tvm_pack.set_output(i, &args[num_inputs_ + i]); + } + tvm_pack.run(); + for (int i = 0; i < num_outputs_; ++i) { + auto& output_buf = buf_infos[num_inputs_ + i]; + output_buf.CopyToOrigin(); + } + return outputs; + } + + /*! + * \brief Load TVM graph runtime module. + * + * \param shapes Input shapes. List[List[int]]. + * \param lib_path Path of .so lib file. + * \param graph_path Path of graph.json file. + * \param params_path Path of params data file. + */ + void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, + const std::string& graph_path, const std::string& params_path) { + std::string path = TvmGraphModulePack::EncodePaths(lib_path, graph_path, params_path); + auto shape_repr = TvmShapeRepr(shapes); + auto it_find = tvm_modules_.find(shape_repr); + if (it_find != tvm_modules_.end()) { + tvm_modules_.erase(it_find); + } + const auto it = + tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)).first; + if (it->second.num_outputs() != num_outputs_) { + LOG(FATAL) << "tvm class num outputs mismatch, expected " << num_outputs_ << ", got " + << it->second.num_outputs(); + } + } + + const std::map& tvm_modules() const { return tvm_modules_; } + + /*! + * \brief Serialize TVM modules to shape map. + * + * \return shape_path_map Dict of shape_repr to path. + */ + c10::Dict SerializeTvmModules() const override { + c10::Dict shape_path_map; + for (const auto& entry : tvm_modules()) { + shape_path_map.insert(entry.first, entry.second.path()); + } + return shape_path_map; + } + + /*! + * \brief Deserialize TVM modules from shape map. + * + * \param shape_path_map Dict of shape_repr to path. + */ + void DeserializeTvmModules(const c10::Dict& shape_path_map) override { + tvm_modules_.clear(); + for (const auto& entry : shape_path_map) { + const auto& shape_repr = entry.key(); + const auto& path = entry.value(); + tvm_modules_.emplace(shape_repr, TvmGraphModulePack(path, device_type_, device_id_)); + } + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + void to(const std::string& device) override { + if (device != this->device()) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + DeserializeTvmModules(SerializeTvmModules()); + } + } + + private: + std::map tvm_modules_; +}; + +/*! \brief Pytorch custom class to call TVM graph runtime */ +class TvmVMRuntimeClass : public BaseTvmClass { + public: + TvmVMRuntimeClass(const int64_t num_inputs, const int64_t num_outputs, const std::string& device) + : BaseTvmClass(num_inputs, num_outputs, device) {} + + /*! + * \brief Module forward. + * + * \param inputs Inputs with type List[Tensor]. + * + * \return outputs with type List[Tensor]. + */ + c10::List forward(const c10::List& inputs) override { + // get inputs repr str + auto shape_repr = TvmShapeRepr(GetShapes(inputs)); + // get tvm pack + auto iter = tvm_modules_.find(shape_repr); + CHECK(iter != tvm_modules_.end()) << "tvm module pack not found for shape_repr " << shape_repr; + const auto& tvm_pack = iter->second; + + // input tensors + CHECK_EQ(inputs.size(), num_inputs_); + std::vector args(num_inputs_); + std::vector args_arr(num_inputs_); + + for (int i = 0; i < num_inputs_; ++i) { + TensorAsBuf input_buf(inputs[i]); + input_buf.CopyFromOrigin(); + input_buf.MakeDLTensor(&args[i]); + args_arr[i] = + tvm::runtime::NDArray::FromDLPack(new DLManagedTensor({args[i], nullptr, nullptr})); + } + // set input + std::vector tvm_values(num_inputs_ + 1); + std::vector tvm_type_codes(num_inputs_ + 1); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + setter(0, "main"); + for (int k = 0; k < num_inputs_; ++k) { + setter(k + 1, args_arr[k]); + } + tvm_pack.set_input.CallPacked( + tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_inputs_ + 1), nullptr); + + // run tvm + tvm::runtime::TVMRetValue ret = tvm_pack.invoke("main"); + + // get outputs + std::vector output_arrs(num_outputs_); + auto output_mismatch_msg = [](int actual, int expected) { + std::stringstream ss; + ss << "num_outputs not equal, actual:[" << actual << "] != expected:[" << expected << "]"; + return ss.str(); + }; + if (ret.type_code() == kTVMNDArrayHandle) { + CHECK_EQ(num_outputs_, 1) << output_mismatch_msg(1, num_outputs_); + output_arrs.at(0) = ret.AsObjectRef(); + } else if (ret.type_code() == kTVMObjectHandle) { + const auto& adt = ret.AsObjectRef(); + CHECK_EQ(adt.size(), num_outputs_) << output_mismatch_msg(adt.size(), num_outputs_); + for (size_t i = 0; i < adt.size(); ++i) { + CHECK(adt[i]->IsInstance()) + << "adt elements not tvm::runtime::NDArray"; + output_arrs.at(i) = tvm::runtime::Downcast(adt[i]); + } + } else { + LOG(FATAL) << "unsupported return type with type_code = " << ret.type_code(); + } + + std::vector output_args(num_outputs_); + c10::List outputs; + outputs.reserve(num_outputs_); + + for (int i = 0; i < num_outputs_; ++i) { + const auto& output_arr = output_arrs[i]; + std::vector output_shape(output_arr->shape, output_arr->shape + output_arr->ndim); + + torch::ScalarType output_dtype = torch::ScalarType::Undefined; + CHECK(GetTorchDtype(output_arr.DataType(), &output_dtype)); + + CHECK(device_type_ == kDLCPU || device_type_ == kDLCUDA); + const c10::DeviceType pt_device_type = (device_type_ == kDLCUDA ? torch::kCUDA : torch::kCPU); + const auto options = + torch::TensorOptions().dtype(output_dtype).device(pt_device_type, device_id_); + + outputs.emplace_back(torch::empty(output_shape, options)); + TensorAsBuf output_buf(outputs[i]); + output_buf.MakeDLTensor(&output_args[i]); + output_arr.CopyTo(&output_args[i]); + output_buf.CopyToOrigin(); + } + return outputs; + } + + /*! + * \brief Load TVM vm runtime module. + * + * \param shapes Input shapes. List[List[int]]. + * \param lib_path Path of .so lib file. + * \param code_path Path of code file. Typically named code.ro + */ + void LoadTvmModule(const c10::List>& shapes, const std::string& lib_path, + const std::string& code_path) { + std::string path = TvmVMModulePack::EncodePaths(lib_path, code_path); + auto shape_repr = TvmShapeRepr(shapes); + auto it_find = tvm_modules_.find(shape_repr); + if (it_find != tvm_modules_.end()) { + tvm_modules_.erase(it_find); + } + tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); + } + + const std::map& tvm_modules() const { return tvm_modules_; } + + /*! + * \brief Serialize TVM modules to shape map. + * + * \return shape_path_map Dict of shape_repr to path. + */ + c10::Dict SerializeTvmModules() const override { + c10::Dict shape_path_map; + for (const auto& entry : tvm_modules()) { + shape_path_map.insert(entry.first, entry.second.path()); + } + return shape_path_map; + } + + /*! + * \brief Deserialize TVM modules from shape map. + * + * \param shape_path_map Dict of shape_repr to path. + */ + void DeserializeTvmModules(const c10::Dict& shape_path_map) override { + tvm_modules_.clear(); + for (const auto& entry : shape_path_map) { + const auto& shape_repr = entry.key(); + const auto& path = entry.value(); + tvm_modules_.emplace(shape_repr, TvmVMModulePack(path, device_type_, device_id_)); + } + } + + /*! + * \brief Move the TVM modules to given device. + * + * \param device String repr of the device to be moved to. + */ + void to(const std::string& device) override { + if (device != this->device()) { + auto torch_device = torch::Device(device); + device_type_ = torch_device.is_cuda() ? kDLCUDA : kDLCPU; + device_id_ = torch_device.index(); + DeserializeTvmModules(SerializeTvmModules()); + } + } + + private: + std::map tvm_modules_; +}; + +// +using SerializeTuple = + std::tuple>; + +/***** registries *****/ +static auto __tvm_dsoop_graph_runtime_registry = + torch::jit::class_("tvm_dsoop", "TvmGraphModule") + .def(torch::init()) + .def("load_tvm_module", &TvmGraphRuntimeClass::LoadTvmModule) + .def("forward", &TvmGraphRuntimeClass::forward) + .def("to", &TvmGraphRuntimeClass::to) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializeTuple { + return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), + self->SerializeTvmModules()); + }, + [](SerializeTuple tuple) -> c10::intrusive_ptr { + auto ptr = c10::make_intrusive( + /*num_inputs=*/std::get<0>(tuple), + /*num_outputs=*/std::get<1>(tuple), + /*device=*/std::get<2>(tuple)); + ptr->DeserializeTvmModules(std::get<3>(tuple)); + return ptr; + }); + +static auto __tvm_dsoop_vm_runtime_registry = + torch::jit::class_("tvm_dsoop", "TvmVMModule") + .def(torch::init()) + .def("load_tvm_module", &TvmVMRuntimeClass::LoadTvmModule) + .def("forward", &TvmVMRuntimeClass::forward) + .def("to", &TvmVMRuntimeClass::to) + .def_pickle( + [](const c10::intrusive_ptr& self) -> SerializeTuple { + return std::make_tuple(self->num_inputs(), self->num_outputs(), self->device(), + self->SerializeTvmModules()); + }, + [](SerializeTuple tuple) -> c10::intrusive_ptr { + auto ptr = c10::make_intrusive( + /*num_inputs=*/std::get<0>(tuple), + /*num_outputs=*/std::get<1>(tuple), + /*device=*/std::get<2>(tuple)); + ptr->DeserializeTvmModules(std::get<3>(tuple)); + return ptr; + }); + +static auto __tvm_shape_repr_fn_registry = + torch::RegisterOperators("tvm_dsoop::tvm_shape_repr", &BaseTvmClass::TvmShapeRepr); +} // namespace pytorch +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/torch/utils.h b/src/contrib/torch/utils.h new file mode 100644 index 000000000000..a98e058ca346 --- /dev/null +++ b/src/contrib/torch/utils.h @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utils.h + * \brief Util functions for pytorch tvm interaction. + */ + +#ifndef TVM_CONTRIB_TORCH_UTILS_H_ +#define TVM_CONTRIB_TORCH_UTILS_H_ + +#include +#include +#include +#include +#ifdef PT_TVMDSOOP_ENABLE_GPU +#include +#endif + +#include +#include + +namespace tvm { +namespace contrib { +namespace pytorch { + +inline bool GetTvmDtype(const caffe2::TypeMeta& dtype, DLDataType* res) noexcept { + if (dtype == torch::kFloat16) { + *res = {kDLFloat, 16, 1}; + } else if (dtype == torch::kFloat32) { + *res = {kDLFloat, 32, 1}; + } else if (dtype == torch::kFloat64) { + *res = {kDLFloat, 64, 1}; + } else if (dtype == torch::kInt8) { + *res = {kDLInt, 8, 1}; + } else if (dtype == torch::kInt16) { + *res = {kDLInt, 16, 1}; + } else if (dtype == torch::kInt32) { + *res = {kDLInt, 32, 1}; + } else if (dtype == torch::kInt64) { + *res = {kDLInt, 64, 1}; + } else if (dtype == torch::kUInt8) { + *res = {kDLUInt, 8, 1}; + } else if (dtype == torch::kBool) { + *res = {kDLInt, 1, 1}; + } else { + return false; + } + return true; +} + +inline bool GetTvmDtype(const caffe2::TypeMeta& dtype, tvm::runtime::DataType* res) noexcept { + DLDataType dlpack_dtype; + + if (!GetTvmDtype(dtype, &dlpack_dtype)) { + return false; + } + *res = tvm::runtime::DataType(dlpack_dtype); + return true; +} + +inline bool GetTorchDtype(const DLDataType& dtype, c10::ScalarType* res) noexcept { + if (dtype.lanes != 1) { + // only scalar type + return false; + } + if (dtype.code == kDLFloat) { + if (dtype.bits == 16) { + *res = torch::kFloat16; + } else if (dtype.bits == 32) { + *res = torch::kFloat32; + } else if (dtype.bits == 64) { + *res = torch::kFloat64; + } else { + return false; + } + } else if (dtype.code == kDLInt) { + if (dtype.bits == 16) { + *res = torch::kInt16; + } else if (dtype.bits == 32) { + *res = torch::kInt32; + } else if (dtype.bits == 64) { + *res = torch::kInt64; + } else if (dtype.bits == 1) { + *res = torch::kBool; + } else { + return false; + } + } else if (dtype.code == kDLUInt) { + if (dtype.bits == 8) { + *res = torch::kUInt8; + } else if (dtype.bits == 1) { + *res = torch::kBool; + } else { + return false; + } + } else { + return false; + } + return true; +} + +inline bool GetTorchDtype(const tvm::runtime::DataType& dtype, c10::ScalarType* res) noexcept { + using tvm::runtime::DataType; + if (dtype == DataType::Float(16)) { + *res = torch::kFloat16; + } else if (dtype == DataType::Float(32)) { + *res = torch::kFloat32; + } else if (dtype == DataType::Float(64)) { + *res = torch::kFloat64; + } else if (dtype == DataType::Int(32)) { + *res = torch::kInt32; + } else if (dtype == DataType::Int(64)) { + *res = torch::kInt64; + } else if (dtype == DataType::Int(1)) { + *res = torch::kBool; + } else if (dtype == DataType::Int(8)) { + *res = torch::kInt8; + } else if (dtype == DataType::Int(16)) { + *res = torch::kInt16; + } else if (dtype == DataType::UInt(8)) { + *res = torch::kUInt8; + } else if (dtype == DataType::Bool()) { + *res = torch::kBool; + } else { + return false; + } + return true; +} + +// Buffer information used for actual computation. +// Each buffer is associated with one PyTorch tensor +// whose underlying buffer is record into "origin_buf". +// For input tensor, we copy data from origin_buf to buf +// and for output tensor, copy data from buf to origin_buf +class TensorAsBuf { + public: + explicit TensorAsBuf(const at::Tensor& tensor) + : pt_device_type_(tensor.device().type()), + device_id_(tensor.device().index()), + origin_shape_(tensor.sizes().begin(), tensor.sizes().end()) { + CHECK(pt_device_type_ == torch::kCUDA || pt_device_type_ == torch::kCPU); + device_type_ = (pt_device_type_ == torch::kCUDA ? kDLCUDA : kDLCPU); + + char* buf = static_cast(tensor.data_ptr()); + this->origin_buf_ = buf; + this->size_ = tensor.nbytes(); + + // const int alignment = 64; + const int alignment = tvm::runtime::kAllocAlignment; + char* aligned = reinterpret_cast(((uint64_t)buf + alignment - 1) & (~(alignment - 1))); + if (buf == aligned) { + this->tensor_ = tensor; + this->buf_ = buf; + this->offset_ = 0; + } else { + const auto options = + torch::TensorOptions().dtype(tensor.dtype()).device(pt_device_type_, device_id_); + this->inline_tensor_ = + torch::empty({static_cast(tensor.nbytes() + alignment)}, options); + this->tensor_ = this->inline_tensor_; + + buf = static_cast(this->tensor_.data_ptr()); + char* buf_aligned = reinterpret_cast(((uint64_t)buf + alignment) & (~(alignment - 1))); + this->buf_ = buf; + this->offset_ = buf_aligned - buf; + } + } + + void CopyToOrigin() { + if (buf_ == origin_buf_) { + return; + } + if (device_type_ == kDLCPU) { + memcpy(origin_buf_, buf_ + offset_, size_); +#ifdef PT_TVMDSOOP_ENABLE_GPU + } else if (device_type_ == kDLCUDA) { + cudaMemcpy(origin_buf_, buf_ + offset_, size_, cudaMemcpyDeviceToDevice); +#endif + } else { + LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type_ + << " is not implemented currently"; + } + } + + void CopyFromOrigin() { + if (buf_ == origin_buf_) { + return; + } + if (device_type_ == kDLCPU) { + memcpy(buf_ + offset_, origin_buf_, size_); +#ifdef PT_TVMDSOOP_ENABLE_GPU + } else if (device_type_ == kDLCUDA) { + cudaMemcpy(buf_ + offset_, origin_buf_, size_, cudaMemcpyDeviceToDevice); +#endif + } else { + LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type_ + << " is not implemented currently"; + } + } + + // Create DLPack tensor from PyTorch tensor + void MakeDLTensor(DLTensor* out) { + const DLDevice dl_ctx{DLDeviceType(device_type_), device_id_}; + DLDataType dlpack_type; + const auto& tensor = this->tensor_; + CHECK(GetTvmDtype(tensor.dtype(), &dlpack_type)); + + out->device = dl_ctx; + out->ndim = origin_shape_.size(); + out->shape = origin_shape_.data(); + out->strides = nullptr; + out->byte_offset = 0; + out->dtype = dlpack_type; + out->data = buf_ + offset_; + } + + std::string DebugString() { + std::stringstream ss; + ss << "dl device: " << device_type_ << "\npt device: " << static_cast(pt_device_type_) + << "\ndevice_id: " << device_id_ << "\nsize: " << size_ << "\noffset: " << offset_ + << "\nshape:"; + for (auto dim : origin_shape_) { + ss << ' ' << dim; + } + ss << std::endl; + return ss.str(); + } + + private: + DLDeviceType device_type_; + c10::DeviceType pt_device_type_; + int device_id_; + + at::Tensor inline_tensor_; + at::Tensor tensor_; + size_t size_; + size_t offset_; + + std::vector origin_shape_; + + char* origin_buf_; + char* buf_; +}; +} // namespace pytorch +} // namespace contrib +} // namespace tvm +#endif // TVM_CONTRIB_TORCH_UTILS_H_ diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 24cae798988e..ad1f51ba6d71 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -166,7 +166,7 @@ transform::Pass BindTarget(Target target) { } static transform::Pass AnnotateEntryFunc(bool b) { - auto fpass = [b](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); }; return tir::transform::CreatePrimFuncPass(fpass, 0, "AnnotateEntryFunc", {}); @@ -237,10 +237,10 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -590,7 +590,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (ShouldAnnotateEntryFunc(target, mixed_mod)) { mixed_pass_list.push_back(AnnotateEntryFunc(true)); @@ -603,11 +602,16 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - if (target->GetAttr("unpacked-api").value_or(Bool(false))) { + // The host Target contains these parameters at the moment rather than + // the specific Target + // TODO(Mousius) - Move these to the Executor object rather than Target + if (target->GetHost().value()->GetAttr("unpacked-api").value_or(Bool(false))) { mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI()); } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1)); diff --git a/src/ir/module.cc b/src/ir/module.cc index 3deb70dd766c..8ea83cfb40f0 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -187,9 +187,12 @@ void WarnIfMalformed(const IRModule& mod, relay::Function func) { auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); // TODO(@jroesch): refactor to use diagnostic context - ICHECK_EQ(fv.size(), 0) << "There are free variables: " << fv << std::endl; - ICHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv - << " in function: " << AsText(func, false); + ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free variables: " << fv; + ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl + << PrettyPrint(func) << std::endl + << "contains free type variables: " << fv; } void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc new file mode 100644 index 000000000000..cf4262814947 --- /dev/null +++ b/src/meta_schedule/integration.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/**************** Utility functions ****************/ + +template +bool HasOnlyOneFunction(const IRModule& mod) { + if (mod->functions.size() != 1) { + return false; + } + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (!func->IsInstance()) { + return false; + } + } + return true; +} + +/**************** ExtractedTask ****************/ + +ExtractedTask::ExtractedTask(String task_name, IRModule mod, Array dispatched) { + ObjectPtr n = make_object(); + n->task_name = task_name; + n->mod = mod; + n->dispatched = dispatched; + data_ = n; +} + +/**************** MetaScheduleContext ****************/ + +struct MetaScheduleContextThreadLocalEntry { + Optional ctx; +}; + +using MetaScheduleContextThreadLocalStore = + dmlc::ThreadLocalStore; + +Optional MetaScheduleContext::Current() { + return MetaScheduleContextThreadLocalStore::Get()->ctx; +} + +void MetaScheduleContext::EnterWithScope() { + Optional& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx; + CHECK(!ctx.defined()) + << "ValueError: Nested MetaScheduleContext context managers are not allowed"; + ctx = *this; +} + +void MetaScheduleContext::ExitWithScope() { + Optional& ctx = MetaScheduleContextThreadLocalStore::Get()->ctx; + ICHECK(ctx.defined()); + ctx = NullOpt; +} + +Optional MetaScheduleContext::QueryInsideWithScope( + runtime::String task_name, IRModule mod, Optional> dispatched) { + if (Optional ctx = MetaScheduleContext::Current()) { + return ctx.value()->Query(task_name, mod, dispatched); + } + return NullOpt; +} + +/**************** TaskExtraction ****************/ + +TaskExtraction::TaskExtraction() { + ObjectPtr n = make_object(); + n->tasks = Array(); + data_ = n; +} + +Optional TaskExtractionNode::Query(runtime::String task_name, IRModule mod, + Optional> dispatched) { + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + ICHECK(HasOnlyOneFunction(mod)) << mod; + tasks.push_back(ExtractedTask(task_name, mod, {prim_mod})); + return NullOpt; +} + +/**************** ApplyHistoryBest ****************/ + +ApplyHistoryBest::ApplyHistoryBest(Database database) { + ObjectPtr n = make_object(); + n->database = database; + data_ = n; +} + +Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, + Optional> dispatched) { + throw; +} + +/**************** FFI ****************/ + +class MetaScheduleContextInternal { + public: + static void EnterScope(MetaScheduleContext ctx) { ctx.EnterWithScope(); } + static void ExitScope(MetaScheduleContext ctx) { ctx.ExitWithScope(); } +}; + +TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); +TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode); +TVM_REGISTER_NODE_TYPE(TaskExtractionNode); +TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") + .set_body_typed([](String task_name, IRModule mod, + Array dispatched) -> ExtractedTask { + return ExtractedTask(task_name, mod, dispatched); + }); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextEnterScope") + .set_body_typed(MetaScheduleContextInternal::EnterScope); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextExitScope") + .set_body_typed(MetaScheduleContextInternal::ExitScope); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextCurrent") + .set_body_typed(MetaScheduleContext::Current); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope") + .set_body_typed(MetaScheduleContext::QueryInsideWithScope); +TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") + .set_body_method(&MetaScheduleContextNode::Query); +TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { + return TaskExtraction(); +}); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index a529f2354d87..3ef5026cae98 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,7 +52,9 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, Builder builder, Runner runner, +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // Database database) { ObjectPtr n = make_object(); n->tasks = tasks; diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index cf0af3d55fe4..08f2b4f451bd 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -92,19 +92,30 @@ Array SendToRunner(const Runner& runner, // return results; } +void TaskSchedulerNode::InitializeTask(int task_id) { + TuneContext task = this->tasks[task_id]; + // Derive the values. + IRModule mod = task->mod.value(); + SpaceGenerator space = task->space_generator.value(); + SearchStrategy strategy = task->search_strategy.value(); + // Initialize Modules. + space->InitializeWithTuneContext(task); + strategy->InitializeWithTuneContext(task); +} + void TaskSchedulerNode::Tune() { - for (const TuneContext& task : this->tasks) { - CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(task->space_generator.defined()) + for (int i = 0; i < static_cast(this->tasks.size()); i++) { + // Check Optional value validity. + CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(tasks[i]->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(task->search_strategy.defined()) + CHECK(tasks[i]->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); - strategy->PreTuning(space->GenerateDesignSpace(mod)); + + InitializeTask(i); + + tasks[i]->search_strategy.value()->PreTuning( + tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value())); } int running_tasks = tasks.size(); @@ -114,7 +125,7 @@ void TaskSchedulerNode::Tune() { ICHECK(!task->is_stopped); ICHECK(!task->runner_futures.defined()); SearchStrategy strategy = task->search_strategy.value(); - if (task->measure_candidates = strategy->GenerateMeasureCandidates()) { + if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { Array builder_results = SendToBuilder(this->builder, task, task->measure_candidates.value()); task->runner_futures = @@ -186,13 +197,23 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { } TaskScheduler TaskScheduler::PyTaskScheduler( + Array tasks, // + Builder builder, // + Runner runner, // + Database database, // PyTaskSchedulerNode::FTune f_tune, // + PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, // PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, // PyTaskSchedulerNode::FNextTaskId f_next_task_id) { ObjectPtr n = make_object(); + n->tasks = tasks; + n->builder = builder; + n->runner = runner; + n->database = database; n->f_tune = f_tune; + n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; n->f_is_task_running = f_is_task_running; n->f_join_running_task = f_join_running_task; @@ -202,14 +223,16 @@ TaskScheduler TaskScheduler::PyTaskScheduler( TVM_REGISTER_OBJECT_TYPE(TaskSchedulerNode); TVM_REGISTER_NODE_TYPE(PyTaskSchedulerNode); -TVM_REGISTER_GLOBAL("tvm.task.TaskSchedulerPyTaskScheduler") +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerPyTaskScheduler") .set_body_typed(TaskScheduler::PyTaskScheduler); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") + .set_body_method(&TaskSchedulerNode::Tune); +TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerInitializeTask") + .set_body_method(&TaskSchedulerNode::InitializeTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerSetTaskStopped") .set_body_method(&TaskSchedulerNode::SetTaskStopped); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerIsTaskRunning") .set_body_method(&TaskSchedulerNode::IsTaskRunning); -TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerTune") - .set_body_method(&TaskSchedulerNode::Tune); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerJoinRunningTask") .set_body_method(&TaskSchedulerNode::JoinRunningTask); TVM_REGISTER_GLOBAL("meta_schedule.TaskSchedulerNextTaskId") diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 94dfda556cc9..09eb02e10bfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include #include @@ -93,16 +95,27 @@ class NodeIndexer : public AttrVisitor { MakeIndex(const_cast(value->get())); } - // make index of all the children of node - void MakeIndex(Object* node) { + void MakeNodeIndex(Object* node) { if (node == nullptr) return; ICHECK(node->IsInstance()); - if (node_index_.count(node)) return; + if (node_index_.count(node)) { + return; + } ICHECK_EQ(node_index_.size(), node_list_.size()); node_index_[node] = node_list_.size(); node_list_.push_back(node); + } + // make index of all the children of node + void MakeIndex(Object* node) { + if (node == nullptr) return; + ICHECK(node->IsInstance()); + + if (node_index_.count(node)) { + return; + } + MakeNodeIndex(node); if (node->IsInstance()) { ArrayNode* n = static_cast(node); for (const auto& sp : *n) { @@ -123,6 +136,21 @@ class NodeIndexer : public AttrVisitor { MakeIndex(const_cast(kv.second.get())); } } + } else if (node->IsInstance()) { + auto pre_visit = [this](const relay::LetNode* op) { + MakeNodeIndex(const_cast(static_cast(op))); + MakeIndex(const_cast(static_cast(op->var.get()))); + MakeIndex(const_cast(static_cast(op->value.get()))); + MakeIndex(const_cast(static_cast(op->span.get()))); + MakeIndex(const_cast(static_cast(op->checked_type_.get()))); + if (!op->body.as()) { + MakeIndex(const_cast(static_cast(op->body.get()))); + } + }; + auto post_visit = [](const relay::LetNode* op) {}; + if (!reflection_->GetReprBytes(node, nullptr)) { + relay::ExpandANormalForm(static_cast(node), pre_visit, post_visit); + } } else { // if the node already have repr bytes, no need to visit Attrs. if (!reflection_->GetReprBytes(node, nullptr)) { diff --git a/src/parser/parser.cc b/src/parser/parser.cc index ebd6566889dc..092d5b61eeec 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1437,8 +1437,8 @@ class Parser { String attr_key = Downcast(raw_attrs["attrs_type_key"]); if (attr_key.size()) { raw_attrs.erase("attrs_type_key"); - auto tbl = tvm::ReflectionVTable::Global(); - auto attr_obj = tbl->CreateObject(attr_key, raw_attrs); + auto attr_obj = + tvm::ReflectionVTable::Global()->CreateObject(attr_key, raw_attrs); ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } @@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr") TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() { return CreateModulePass( [](const IRModule& mod, const PassContext& ctx) { - auto text = AsText(mod, true); + String text = AsText(mod, /*show_meta_data=*/true); + VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; return ParseModule("GeneratedSource", text); }, 0, "AnnotateSpans", {}); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ea97bb35a09f..7454cfdf336e 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -37,6 +37,7 @@ #include #include #include +#include #include #include "../ir/attr_functor.h" @@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); - } else if (!show_meta_data_ && node.as()) { - // Show attributes in readable form. - return PrintAttrs(Downcast(node)); } else { // default module. std::ostringstream os; @@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (Var param : fn->params) { params.push_back(AllocVar(param)); } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } doc << Doc::Concat(params) << ") "; @@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { Doc doc; doc << "Tensor[("; std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); + for (const PrimExpr& prim_expr : node->shape) { + // Though not bound within an attribute the attribute visitor will handle the PrimExprs we + // care about. + shapes.push_back(PrintAttributeValue(prim_expr)); } doc << Doc::Concat(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; @@ -766,36 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { // Overload of Attr printing functions //------------------------------------ -Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (auto str_obj = value.as()) { - printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (const auto* on_device_attrs = value.as()) { - printed_attr << "device_type=" << on_device_attrs->device_type; - } else if (meta) { - printed_attr = meta_->GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; - } else { - return Doc::Text("None"); - } -} - Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), true); + // Since we don't have any overload for a specific attribute type we'll need to force + // the meta[...] representation to avoid infinite regress. + return PrintAttributeValue(GetRef(op), /*force_meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { Doc doc; doc << "["; std::vector arr_vals; - for (auto val : *op) { - arr_vals.push_back(PrintAttr(val)); + for (const auto& val : *op) { + arr_vals.push_back(PrintAttributeValue(val)); } doc << Doc::Concat(arr_vals); doc << "]"; @@ -833,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { doc << key << "=" << *value << "f"; docs->push_back(doc); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, int* value) final { PrintKV(key, *value); } @@ -846,7 +829,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { LOG(FATAL) << "do not allow NDarray as argument"; } void Visit(const char* key, runtime::ObjectRef* obj) final { - PrintKV(key, parent_->PrintAttr(*obj)); + PrintKV(key, parent_->PrintAttributeValue(*obj)); } private: @@ -854,50 +837,126 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { RelayTextPrinter* parent_; }; -Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) { - std::vector docs; - AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - Doc doc; - doc << "{" << Doc::Concat(docs) << "}"; - - return doc; +void RelayTextPrinter::AppendGenericAttrs(std::vector* docs, const Attrs& attrs, + bool include_type_key) { + if (!attrs.defined()) { + return; + } + AttrPrinter printer(docs, this); + // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this + // case we are read-only. + const_cast(attrs.get())->VisitNonDefaultAttrs(&printer); + if (include_type_key) { + std::string s = attrs->GetTypeKey(); + printer.Visit("attrs_type_key", &s); + } } std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; - if (!attrs.defined()) return docs; + if (!attrs.defined()) { + return docs; + } const auto* op_node = op.as(); if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) { - // fallback + // The parser can only understand calls with attributes if they match the operator's + // declared attribute type. If that's not the case fall back to the meta[...] representation. + docs.push_back(meta_->GetMetaNode(attrs)); + } else { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node); + } + return docs; +} + +std::vector RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) { + if (!dict_attrs.defined()) { + return {}; + } + return PrintDictAttrs(dict_attrs->dict); +} + +std::vector RelayTextPrinter::PrintDictAttrs(const Map& dict_attrs) { + std::vector docs; + if (!dict_attrs.defined()) { + return docs; + } + for (const auto& k : dict_attrs) { Doc doc; - doc << meta_->GetMetaNode(attrs); + doc << k.first << "=" << PrintAttributeValue(k.second); docs.push_back(doc); - return docs; - } else { - // Show attributes in readable form. - AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - if (!op_node) { - // print call attr type key to restore expr for relay parser - std::string s = std::string(attrs->GetTypeKey()); - printer.Visit("attrs_type_key", &s); + } + return docs; +} + +Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(GetRef(str_obj)); + } else if (force_meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); + } else if (const auto* se_scope_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(se_scope_node)); + } else { + // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while + // debugging. + std::ostringstream os; + os << GetRef(se_scope_node); + return Doc::Text(os.str()); + } + } else if (const auto* base_attr_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_attr_node)); + } else { + // Special case: The non-meta form for attributes are much easier to work with while + // debugging. + printed_attr = PrintAttrsAsAttributeValue(GetRef(base_attr_node)); + } + } else if (const auto* base_map_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_map_node)); + } else { + // Special case: Show maps fields as key=value pairs to help debugging. + printed_attr << PrintMapAsAttributeValue(GetRef>(base_map_node)); + } + } else if (const auto* global_var_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(global_var_node)); + } else { + printed_attr << "'" << global_var_node->name_hint << "'"; + } + } else { + printed_attr = VisitAttr(value); } - return docs; + return printed_attr; + } else { + return Doc::Text("None"); } } -std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { +Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) { std::vector docs; - if (!attrs.defined()) return docs; - const auto* dict_attrs = attrs.as(); - ICHECK(dict_attrs); - for (const auto& k : dict_attrs->dict) { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false); + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; +} + +Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& map) { + std::vector docs; + for (const auto& k : map) { Doc doc; - doc << k.first << "=" << Print(k.second); + doc << PrintAttributeValue(k.first); + doc << "="; + doc << PrintAttributeValue(k.second); docs.push_back(doc); } - return docs; + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; } Doc RelayTextPrinter::PrintSpan(const Span& span) { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index b8533a5d8801..444cb0828c94 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -58,6 +58,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { os << "def @" << kv.first->name_hint; doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); } else if (kv.second.as()) { + doc << "@" << kv.first->name_hint << " = "; doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); } doc << Doc::NewLine(); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a2178167b2e3..ebd667ae2ac7 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -77,9 +77,42 @@ class RelayTextPrinter : public ExprFunctor, // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const ObjectRef& node); Doc PrintFinal(const ObjectRef& node); - Doc PrintAttrs(const Attrs& attrs); + + /*! + * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence + * of key=value entries, if any. + */ + void AppendGenericAttrs(std::vector* docs, const Attrs& attrs, bool include_type_key); + + /*! + * \brief Returns \p attrs printed as a sequence of key=value entries, if any. + * This is used for call attributes. + */ std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); + + /*! + * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. + * This is used for function definition attributes. + */ + std::vector PrintDictAttrs(const DictAttrs& dict_attrs); + std::vector PrintDictAttrs(const Map& dict_attrs); + + /*! + * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta + * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag. + */ + Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); + + /*! + * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. + */ + Doc PrintAttrsAsAttributeValue(const Attrs& attrs); + + /*! + * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. + */ + Doc PrintMapAsAttributeValue(const Map& map); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); @@ -162,7 +195,6 @@ class RelayTextPrinter : public ExprFunctor, //------------------------------------ // Overload of Attr printing functions //------------------------------------ - Doc PrintAttr(const ObjectRef& value, bool meta = false); Doc VisitAttrDefault_(const Object* op) final; Doc VisitAttr_(const ArrayNode* op) final; Doc VisitAttr_(const tir::IntImmNode* op) final; @@ -379,6 +411,9 @@ class TIRTextPrinter : public StmtFunctor, String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate); + } // namespace tir } // namespace tvm diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 302c4491cebe..e479af1b2fe9 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -71,6 +72,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintString(node.as()); } else if (node->IsInstance()) { return PrintBufferRegion(node.as()); + } else if (node->IsInstance()) { + return Doc::Text(node.as()->ToDebugString()); } else { return this->meta_->GetMetaNode(node); } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 13e4cfcd30ba..a47712e6b62a 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -91,7 +91,7 @@ class TVMScriptPrinter : public StmtFunctor, */ TVM_DLL Doc Print(const ObjectRef& node); - private: + protected: /*! \brief The tir prefix */ String tir_prefix_; /*! \brief whether show meta data */ @@ -119,8 +119,6 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ std::unordered_map memo_buf_decl_; - /*! \brief Map from CommReducer to Doc */ - std::unordered_map memo_reducer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief number of children of current node's parent */ @@ -208,8 +206,10 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintBlockVars(const BlockRealizeNode* op); Doc PrintBlockAttr(const BlockRealizeNode* op); Doc PrintBlockBody(const BlockNode* op); + virtual Doc PrintBlockName(const BlockNode* block_op); Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); + Doc PrintCommReducer(const CommReducerNode* op); Doc PrintAnnotations(const Map& annotations); static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } @@ -217,15 +217,24 @@ class TVMScriptPrinter : public StmtFunctor, Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); void TryDeallocVar(const Var& var); + bool ContainsOptionalInfo(const Stmt& stmt); /*! Helper functions for loop printing. */ /*! * \brief Print a single for loop * \param loop The for loop to be printed */ - Doc PrintLoop(const For& loop); + virtual Doc PrintLoop(const For& loop); /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); + /*! + * \brief Print all simple loops in stack into one line using tir_prefix_.grid(). + * \param for_op the for node to be checked + */ + bool IsSimpleLoop(const ForNode* for_op) { + return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && + is_zero(for_op->min) && !ContainsOptionalInfo(GetRef(for_op)); + } /*! * \brief Print additional info about expr in comment. @@ -234,11 +243,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintOptionalInfo(const Stmt& stmt) { Doc doc; // default annotations - if (annotate_ != nullptr) { + if (ContainsOptionalInfo(stmt)) { std::string annotated_stmt = annotate_(stmt); - if (!annotated_stmt.empty()) { - doc << "# " << annotated_stmt << Doc::NewLine(); - } + doc << "# " << annotated_stmt << Doc::NewLine(); } return doc; } @@ -303,6 +310,17 @@ class TVMScriptPrinter : public StmtFunctor, } return doc; } + + public: + static Doc PrintHeader(const std::string& tir_prefix) { + Doc header; + if (tir_prefix != "tir") { + header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine(); + } else { + header << "# from tvm.script import tir" << Doc::NewLine(); + } + return header; + } }; Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { @@ -391,6 +409,16 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +/*! + * \brief Check if any optional information exists in annotate_ for + * a given Stmt. + * \param stmt The statement. + */ +bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { + if (annotate_ == nullptr) return false; + return !annotate_(stmt).empty(); +} + /*! * \brief Try to dealloc vars out of space and leave the index to coming vars. * \note It is not a necessary step. @@ -427,6 +455,39 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { return doc; } +Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { + Doc doc; + int n_var = static_cast(op->rhs.size()); + + doc << tir_prefix_ << ".comm_reducer(lambda "; + for (const Var& v_lhs : op->lhs) { + doc << Print(v_lhs) << ", "; + } + for (int i = 0; i < n_var; ++i) { + doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", "); + } + if (n_var == 1) { + doc << Print(op->result[0]) << ", "; + } else { + doc << "("; + for (int i = 0; i < n_var; ++i) { + doc << Print(op->result[i]); + if (i != n_var - 1) { + doc << ", "; + } + } + doc << "), "; + } + doc << Print(op->identity_element) << ")"; + + // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda. + for (int i = 0; i < n_var; ++i) { + memo_var_.erase(op->lhs[i]); + memo_var_.erase(op->rhs[i]); + } + return doc; +} + Doc TVMScriptPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("None"); if (node->IsInstance()) { @@ -454,6 +515,8 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintBufferRegion(node.as()); } else if (node->IsInstance()) { return PrintMatchBufferRegion(node.as()); + } else if (node->IsInstance()) { + return PrintCommReducer(node.as()); } else { LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); return Doc(); @@ -526,7 +589,8 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision) -TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ", ExprPrecedence::kMultiplicationDivision) +TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", ExprPrecedence::kMultiplicationDivision) +TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", ExprPrecedence::kAdditionSubtraction) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", ExprPrecedence::kRelational) @@ -538,17 +602,10 @@ TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", ExprPrecedence::kEquality) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", ExprPrecedence::kAnd) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr) -Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) { +Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << tir_prefix_ << ".floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; - return doc; -} - -Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) { - *out_precedence = ExprPrecedence::kIdentity; - Doc doc; - doc << tir_prefix_ << ".floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } @@ -835,14 +892,14 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { var_not_in_headers_.insert(op->loop_var.get()); loop_var_map_[op->loop_var.get()] = GetRef(op); const auto* body = op->body.as(); - bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + bool simple_loop = IsSimpleLoop(op); if (simple_loop) simple_loop_stack_.push_back(GetRef(op)); // It is a loop that can be compressed, let the loops below print it out - if (simple_loop && body != nullptr) { - Doc result = Print(GetRef(body)); + if (simple_loop && body != nullptr && IsSimpleLoop(body)) { + doc << Print(GetRef(body)); TryDeallocVar(op->loop_var); loop_var_map_.erase(op->loop_var.get()); - return result; + return doc; } // It is a loop that can not be compressed bool print_above = !simple_loop_stack_.empty(); @@ -916,6 +973,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +/*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis."; @@ -1049,15 +1107,25 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { return body; } -Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { - const auto* block_op = op->block.as(); - // print block name and block vars +/*! + * \brief Print the name of a block + * \param block_op The block node to be printed + */ +Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { Doc doc; doc << "with " << tir_prefix_ << ".block("; if (!block_op->name_hint.empty()) { doc << Doc::StrLiteral(block_op->name_hint); } doc << "):"; + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + Doc doc = PrintOptionalInfo(GetRef(block_op)); + // print block name and block vars + doc << PrintBlockName(block_op); Doc block_var = PrintBlockVars(op); // print predicate, binding, read/write tensor region, annotations Doc block_attr_doc = PrintBlockAttr(op); @@ -1124,7 +1192,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { memo_var_.clear(); memo_buf_.clear(); memo_buf_decl_.clear(); - memo_reducer_.clear(); var_not_in_headers_.clear(); buf_not_in_headers_.clear(); // print signature @@ -1149,15 +1216,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second]; body << ")" << Doc::NewLine(); } - // print comm_reducer - for (const auto& it : memo_reducer_) { - body << it.second << " = .comm_reducer("; - var_not_in_headers_.insert(it.first->lhs[0].get()); - var_not_in_headers_.insert(it.first->rhs[0].get()); - body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": " - << Print(it.first->result[0]) << ", " << Print(it.first->identity_element[0]); - body << ")" << Doc::NewLine(); - } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1343,12 +1401,65 @@ Doc TVMScriptPrinter::PrintLoopStack() { return res; } +/*! + * \brief The printer for TVMScript with diagnostic + * \details The printer obtain the precedence of the top-level operation when printing each + * subexpression to decide whether or not parentheses is needed. + */ +class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { + public: + explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) + : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} + + protected: + Doc PrintBlockName(const BlockNode* block_op) override; + Doc PrintUnderline(const Stmt& stmt, int length); + Doc PrintLoop(const For& loop) override; +}; + +Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { + Doc doc = TVMScriptPrinter::PrintBlockName(block_op); + doc << PrintUnderline(GetRef(block_op), doc.str().size()); + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { + Doc doc; + // annotation + if (ContainsOptionalInfo(stmt)) { + String underline = std::string(length, '^'); + doc << Doc::NewLine() << underline; + } + return doc; +} + +Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { + Doc res = TVMScriptPrinter::PrintLoop(loop); + res << PrintUnderline(loop, res.str().size()); + return res; +} + String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); - return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; + Doc doc; + doc << TVMScriptPrinter::PrintHeader(tir_prefix) + << TVMScriptPrinter(tir_prefix, show_meta).Print(mod); + return doc.str() + "\n"; } TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); +String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, + runtime::TypedPackedFunc annotate) { + ICHECK(mod->IsInstance() || mod->IsInstance()); + Doc doc; + doc << TVMScriptPrinter::PrintHeader(tir_prefix) + << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod); + return doc.str() + "\n"; +} + +TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic); + } // namespace tir } // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 22e2e9a71040..1421906a3bbb 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -131,12 +131,11 @@ class TypeSolver::Unifier : public TypeFunctor { Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); if (!resolved.defined()) { - solver_->diag_ctx_.Emit( - Diagnostic::Error(this->span) - << "The Relay type checker is unable to show the following types match.\n" - << "In particular " - << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" - << PrettyPrint(rhs->resolved_type) << "`"); + solver_->Emit(Diagnostic::Error(this->span) + << "The Relay type checker is unable to show the following types match.\n" + << "In particular " + << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" + << PrettyPrint(rhs->resolved_type) << "`"); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -233,11 +232,10 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->diag_ctx_.Emit(Diagnostic::Error(this->span) - << "tensor type `" << PrettyPrint(tt1) << "` has " - << tt1->shape.size() << " dimensions, while `" - << PrettyPrint(tt2) << "` has " << tt2->shape.size() - << " dimensions"); + this->solver_->Emit(Diagnostic::Error(this->span) + << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() + << " dimensions, while `" << PrettyPrint(tt2) << "` has " + << tt2->shape.size() << " dimensions"); return Type(nullptr); } @@ -266,7 +264,7 @@ class TypeSolver::Unifier : public TypeFunctor { err << "dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) << " does not match " << std::get<2>(mismatch) << "."; } - this->solver_->diag_ctx_.Emit(err); + this->solver_->Emit(err); return Type(nullptr); } @@ -526,7 +524,7 @@ class TypeSolver::Merger : public TypeFunctor { // constructor TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) : reporter_(make_object(this)), - current_func(current_func), + current_func_(current_func), diag_ctx_(diag_ctx), module_(diag_ctx->module) { ICHECK(module_.defined()); @@ -618,7 +616,7 @@ bool TypeSolver::Solve() { rnode->resolved = resolved; } catch (const CompileError& err) { - this->diag_ctx_.Emit(Diagnostic::Error(rnode->span) << err.what()); + this->Emit(Diagnostic::Error(rnode->span) << err.what()); rnode->resolved = false; } catch (const Error& e) { ICHECK(false) << e.what(); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 56cea60ceeda..3bde1a1e3746 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -94,7 +94,7 @@ class TypeSolver { * \brief Report a diagnostic. * \param diag The diagnostic to report. */ - void EmitDiagnostic(const Diagnostic& diag); + void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } private: class OccursChecker; @@ -176,13 +176,9 @@ class TypeSolver { /*! \brief Reporter that reports back to self */ TypeReporter reporter_; /*! \brief The global representing the current function. */ - GlobalVar current_func; - - public: + GlobalVar current_func_; /*! \brief The diagnostic context. */ DiagnosticContext diag_ctx_; - - private: /*! \brief The module. */ IRModule module_; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 3c9c35c4f254..58bcccf90879 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -40,6 +41,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" @@ -72,14 +74,34 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { AssignReturnSid(GetRef(op)); } - void DeviceAwareVisitExpr_(const CallNode* op) final { - // create token for the call node. - VisitExpr(op->op); - CreateStorage(op); - for (Expr arg : op->args) { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { + // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case + // where the op of the call is a generic function + + Expr func; + Array args; + + if (call_node->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + func = call_lowered_props.lowered_func; + args = call_lowered_props.arguments; + } else { // Relay functions that have not been lowered and lowered extern functions + func = call_node->op; + args = call_node->args; + if (call_node->op.as()) { // Lowered extern function + ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; + } else { // Relay function which has not been lowered yet + ICHECK(call_node->op.as()) + << "Expected the call to be to a lowered primfunc, a lowered extern function or a " + "unlowered Relay function."; + } + } + VisitExpr(func); + CreateStorage(call_node); + for (const Expr& arg : args) { GetStorage(arg); } - AssignReturnSid(GetRef(op)); + AssignReturnSid(GetRef(call_node)); } void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef(op)); } @@ -287,13 +309,18 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Call a function with a given name + * brief Create a function call + * \param call_lowered_props The lowered function and the arguments to call it with + * \param call The call we got func and args from */ - void CreateFuncCall(Call call, std::string func_name) { + void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) { + std::string func_name = call_lowered_props.lowered_func->name_hint; + tvm::Array args{tvm::tir::StringImm(func_name)}; std::vector create_func_call_stmts; + // Pack the inputs - for (Expr arg : call->args) { + for (const Expr& arg : call_lowered_props.arguments) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[arg])}); @@ -371,21 +398,25 @@ class AOTExecutorCodegen : public MixedModeVisitor { return ss.str(); } - void VisitExpr_(const CallNode* op) override { + void VisitExpr_(const CallNode* call_node) override { // Descend the call tree - for (auto arg : op->args) { - VisitExpr(arg); - } - - if (op->op.as()) { - LOG(FATAL) << "Operators should be transformed away; try applying" - << "the fuse_ops transformation to the expression."; - } else if (op->op.as()) { - GlobalVar node = GetRef(op->op.as()); - CreateFuncCall(GetRef(op), node->name_hint); + CallLoweredProps call_lowered_props; + if (const auto* gvn = call_node->op.as()) { // Lowered extern function + ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes."; + for (const auto& arg : call_node->args) { + VisitExpr(arg); + } + call_lowered_props = CallLoweredProps{GetRef(gvn), call_node->args, {}}; } else { - LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); + ICHECK(call_node->op == CallLoweredOp()) << "Operators should be transformed away; Try " + "applying the fuse_ops transformation to the " + "expression."; + call_lowered_props = GetCallLoweredProps(call_node); + for (const auto& arg : call_lowered_props.arguments) { + VisitExpr(arg); + } } + CreateFuncCall(call_lowered_props, GetRef(call_node)); } void VisitExpr_(const VarNode* op) override { @@ -443,7 +474,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) override { - LOG(FATAL) << "All OpNodes should have been expanded"; + if (GetRef(op) != CallLoweredOp()) { + LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; + } } void VisitExpr_(const IfNode* op) override { LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called"; @@ -715,7 +748,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map targets = args[1]; + TargetMap targets = args[1]; init(mod, targets); }); } else if (name == "codegen") { @@ -758,7 +791,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: - void init(void* mod, Map tmp) { + void init(void* mod, TargetMap tmp) { tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ef82ed617508..4dd12ad1d106 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -33,7 +33,7 @@ #include "../../target/func_registry_generator.h" #include "../../target/source/codegen_source_base.h" -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -43,7 +43,6 @@ Pass LabelOps(); } namespace backend { -using TargetsMap = Map; using namespace tvm::relay::transform; /*! @@ -56,7 +55,7 @@ struct BuildOutput { }; struct ExecutorCodegen { - void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } + void Init(runtime::Module* m, TargetMap targets) { CallFunc("init", m, targets); } void Codegen(const Function& func, String mod_name) { CallFunc("codegen", func, mod_name); } @@ -278,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host, + void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, const String executor, const String mod_name) { for (const auto& pair : targets) { VLOG(0) << "Build target " << pair.first << " = " << pair.second->str(); @@ -295,8 +294,6 @@ class RelayBuildModule : public runtime::ModuleNode { executor_ = executor; CheckAndUpdateHostConsistency(&targets_, &target_host_); BuildRelay(mod, params_, mod_name); - // Clear compile engine so that tuning schedules can be changed between runs. See issue #6096. - CompileEngine::Global()->Clear(); } protected: @@ -309,7 +306,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + IRModule Optimize(IRModule relay_module, const TargetMap& targets, const std::unordered_map& params) { targets_ = targets; // No target_host setup it seems. @@ -446,7 +443,7 @@ class RelayBuildModule : public runtime::ModuleNode { const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); if (!target_host.defined()) target_host = (pf != nullptr) ? Target("llvm") : Target("stackvm"); - // Update all the targets in the targets_ TargetsMap + // Update all the targets in the targets_ TargetMap CheckAndUpdateHostConsistency(&targets_, &target_host); // Relay IRModule -> IRModule optimizations. @@ -504,7 +501,7 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array{}); } } else { - ret_.mod = tvm::build(lowered_funcs, target_host_); + ret_.mod = tvm::build(lowered_funcs, target_host); } auto ext_mods = executor_codegen_->GetExternalModules(); @@ -542,7 +539,7 @@ class RelayBuildModule : public runtime::ModuleNode { protected: std::unique_ptr executor_codegen_; /*! \brief target device */ - TargetsMap targets_; + TargetMap targets_; /*! \brief target host device */ tvm::Target target_host_; /*! \brief parameters */ diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc deleted file mode 100644 index 0e7af2278375..000000000000 --- a/src/relay/backend/compile_engine.cc +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.cc - * \brief Internal compilation engine. - */ -#include "compile_engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../runtime/meta_data.h" -#include "../transforms/pass_utils.h" -#include "te_compiler_cache.h" -#include "utils.h" - -namespace tvm { -namespace relay { - -TVM_REGISTER_OBJECT_TYPE(CompileEngineNode); - -class CompileEngineImpl : public CompileEngineNode { - public: - // Lower the function. - CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) { - return LowerInternal(key, mangle_fn)->cached_func; - } - - CachedFunc Lower(const CCacheKey& key, const String mod_name) { - auto mangle_fn = [mod_name](String name) { return runtime::get_name_mangled(mod_name, name); }; - - return Lower(key, mangle_fn); - } - - // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - auto mangle_fn = [](String name) { return name; }; - CCacheValue value = LowerInternal(key, mangle_fn); - if (value->packed_func != nullptr) return value->packed_func; - auto m = build(value->cached_func->funcs, key->target, Target(nullptr)); - value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint); - return value->packed_func; - } - - CachedFunc LowerShapeFunc(const CCacheKey& key) final { - return LowerShapeFuncInternal(key)->cached_func; - } - - Array LowerExternalFunctions() { - Array ret; - std::unordered_map cached_symbol; - std::vector cached_ext_funcs; - for (const auto& it : cache_) { - auto src_func = it.first->source_func; - ICHECK(src_func.defined()); - - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); - ICHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen.value(); - cached_ext_funcs.push_back(it.first); - - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" - << AsText(src_func, false) << "\n" - << "Functions with external codegen must have the " - << tvm::attr::kGlobalSymbol << " attr set."; - - std::string sn = symbol_name.value(); - if (!cached_symbol.count(sn)) { - cached_symbol[sn] = code_gen_name; - } else { - ICHECK_NE(cached_symbol[sn], code_gen_name) - << "Found duplicated symbol: " << sn << " for: " << code_gen_name; - } - - std::string ext_name = "relay.ext." + code_gen_name; - auto pf = tvm::runtime::Registry::Get(ext_name); - ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; - // No need to keep compiler attribute at this point, functions have been - // extracted for specific codegen. - src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); - runtime::Module ext_mod = (*pf)(src_func); - - // todo(@zhiics, @jroesch): Should this be a user visible error? - ICHECK(ext_mod.defined()) << "No external library was generated for " << ext_name - << "even though it was requested" - "by the annotated function " - << PrettyPrint(src_func); - - ret.push_back(ext_mod); - } - } - - // No need to cache external functions as we collected them all to create - // external runtime modules. - for (const auto& it : cached_ext_funcs) { - cache_.erase(it); - } - return ret; - } - - void Clear() final { cache_.clear(); } - - // List all items in the cache. - Array ListItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - // List all items in the shape_func_cache. - Array ListShapeFuncItems() { - std::lock_guard lock(mutex_); - Array items; - for (auto& kv : shape_func_cache_) { - items.push_back(kv.first); - items.push_back(kv.second); - } - return items; - } - - /*! - * \brief Get the cache key of the function that is being lowered currently - * \return the cache key - */ - CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; } - - private: - // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key, std::function mangle_fn) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = cache_.find(key); - if (it != cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - if (!backend::IsCompileEngineCacheDisabled()) { - cache_[key] = value; - } - } - cur_ccache_key_ = key; - - // No need to lower external functions for now. We will invoke the external - // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { - auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(name_node.defined()) << "External function has not been attached a name yet."; - auto func_name = std::string(name_node.value()); - auto target = Target("ext_dev"); - auto global_var = GlobalVar(func_name); - global_var->checked_type_ = key->source_func->checked_type(); - ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); - return value; - } - - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(mangle_fn(name), &name_map_); - }); - - // Skip lowering for device copy node. - const Expr body = (key->source_func)->body; - if (const CallNode* call_node = body.as()) { - if (call_node->attrs.as()) { - value->cached_func = cfunc; - return value; - } - } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); - } - // lower the function - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); - value->cached_func = cfunc; - - return value; - } - - // implement lowered shape func - CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { - std::lock_guard lock(mutex_); - CCacheValue value; - auto it = shape_func_cache_.find(key); - if (it != shape_func_cache_.end()) { - it->second->use_count += 1; - if (it->second->cached_func.defined()) return it->second; - value = it->second; - } else { - value = CCacheValue(make_object()); - value->use_count = 0; - shape_func_cache_[key] = value; - } - // Enforce use the target. - With target_scope(key->target); - - ICHECK(!value->cached_func.defined()); - using tvm::transform::PassContext; - With fresh_pass_ctx_scope(PassContext::Create()); - - auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) { - return GetUniqueName(name, &name_map_); - }); - - value->cached_func = cached_func; - return value; - } - - /*! \brief compiler cache lock*/ - std::mutex mutex_; - /*! \brief internal name map to get an unique name */ - std::unordered_map name_map_; - /*! \brief internal compiler cache */ - std::unordered_map cache_; - /*! \brief internal compiler cache for shape funcs */ - std::unordered_map shape_func_cache_; - /*! \brief the cache key of the function that is being lowered currently*/ - CCacheKey cur_ccache_key_; -}; - -/*! \brief The global compile engine */ -CompileEngine& CompileEngine::Global() { - // intentionally allocate raw pointer to avoid - // free during destructuion. - static CompileEngine* inst = new CompileEngine(make_object()); - return *inst; -} - -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); - -TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") - .set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); - }); - -TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { - return CompileEngine::Global(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { - self->Clear(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key, const String mod_name) { - return self->Lower(key, mod_name); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") - .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListItems(); -}); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListShapeFuncItems") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->ListShapeFuncItems(); - }); - -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGetCurrentCCacheKey") - .set_body_typed([](CompileEngine self) { - CompileEngineImpl* ptr = dynamic_cast(self.operator->()); - ICHECK(ptr != nullptr); - return ptr->GetCurrentCCacheKey(); - }); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h deleted file mode 100644 index 4afdc6d30485..000000000000 --- a/src/relay/backend/compile_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file relay/backend/compile_engine.h - * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns. - * - * This layer represents the older design of the Relay compilation flow and is being deprecated - * in favor of te_compiler.h which is a migration step towards a standard pass based lowering of - * Relay functions. - * - */ -#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ -#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "te_compiler_cache.h" - -namespace tvm { -namespace relay { - -using namespace tvm::relay::tec; - -/*! - * \brief Backend compilation engine for - * low level code generation. - */ -class CompileEngineNode : public Object { - public: - /*! \brief destructor */ - virtual ~CompileEngineNode() {} - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The mangling function for mangling names. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, std::function mangle_fn) = 0; - - /*! - * \brief Get lowered result. - * \param key The key to the cached function. - * \param mod_name The module name to mangle the functions. - * \return The result. - */ - virtual CachedFunc Lower(const CCacheKey& key, const String mangle_fn) = 0; - /*! - * \brief Just in time compile to get a PackedFunc. - * \param key The key to the cached function. - * \return The result. - */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; - /*! - * \brief Lower the shape function. - * \param key The key to the cached function. - * \return The result. - */ - virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; - /*! - * \brief Lower the external function using external codegen tools. - * \return The runtime moduels for each needed external codegen tool. - */ - virtual tvm::Array LowerExternalFunctions() = 0; - - /*! \brief clear the cache. */ - virtual void Clear() = 0; - - // VisitAttrs - void VisitAttrs(AttrVisitor*) {} - - static constexpr const char* _type_key = "relay.CompileEngine"; - TVM_DECLARE_FINAL_OBJECT_INFO(CompileEngineNode, Object); -}; - -/*! \brief cache entry used in compile engine */ -class CompileEngine : public ObjectRef { - public: - CompileEngine() {} - explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { return static_cast(get_mutable()); } - using ContainerType = CompileEngineNode; - /*! \brief The global compile engine. */ - TVM_DLL static CompileEngine& Global(); -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 3c3346340f04..bd0ac52330d5 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -17,6 +17,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -33,29 +34,46 @@ namespace relay { namespace contrib { namespace cmsisnn { -class RelayToTIRVisitor : public MixedModeVisitor { +class RelayToTIRVisitor : public MixedModeMutator { public: - explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {} + explicit RelayToTIRVisitor(IRModule ir_module, Target target) + : ir_module_(ir_module), target_(target) {} - tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; } + IRModule Mutate() { + GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); + BaseFunc main = ir_module_->Lookup(main_global_var); + Function main_func = GetRef(main.as()); + + // Copy everything across and mutate the body + Function mutated_main = + Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, + main_func->type_params, main_func->attrs, main_func->span); + + ir_module_->Update(main_global_var, mutated_main); + + return ir_module_; + } private: inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); } - void CreatePrimFuncForExtern(Array func_signature, + void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, tvm::Array call_extern_args) { Map dict_attrs; - dict_attrs.Set("global_symbol", func_name_); + dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); + dict_attrs.Set(tvm::attr::kTarget, target_); dict_attrs.Set("tir.noalias", Bool(true)); tir::Stmt body = tir::Evaluate( tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); - primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map(), - DictAttrs(dict_attrs)); + tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), + DictAttrs(dict_attrs)); + + ir_module_->Add(global_var, replacement_func); } - void EmitSoftMax(const Expr& expr) { + void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) { auto* quantize_call = expr.as(); auto* softmax_call = quantize_call->args[0].as(); auto* dequant_call = softmax_call->args[0].as(); @@ -102,10 +120,10 @@ class RelayToTIRVisitor : public MixedModeVisitor { out_var, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void EmitMul(const Expr& expr) { + void EmitMul(const GlobalVar& global_var, const Expr& expr) { auto* mul_call = expr.as(); const float input_0_scale = GetScalarFromConstant(mul_call->args[2]); @@ -145,10 +163,10 @@ class RelayToTIRVisitor : public MixedModeVisitor { tensor_size, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void EmitAdd(const Expr& expr) { + void EmitAdd(const GlobalVar& global_var, const Expr& expr) { auto* add_call = expr.as(); const float input_0_scale = GetScalarFromConstant(add_call->args[2]); @@ -212,58 +230,59 @@ class RelayToTIRVisitor : public MixedModeVisitor { tensor_size, }; - CreatePrimFuncForExtern(func_signature, args); + CreatePrimFuncForExtern(global_var, func_signature, args); } - void VisitExpr_(const CallNode* call) final { - auto* func = call->op.as(); - if (func == nullptr) { - return; - } - - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - if (comp_name == "cmsisnn.quantized_softmax") { - EmitSoftMax(func->body); + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call = post.as()) { + auto* func = call->op.as(); + if (func == nullptr) { + return post; } - if (comp_name == "cmsisnn.quantized_mul") { - EmitMul(func->body); - } - if (comp_name == "cmsisnn.quantized_add") { - EmitAdd(func->body); + + auto codegen_name = func->GetAttr(attr::kCompiler); + if (codegen_name.defined() && codegen_name == "cmsis-nn") { + const CallNode* inner_call = func->body.as(); + const FunctionNode* composite_func = inner_call->op.as(); + auto comp_name = composite_func->GetAttr(attr::kComposite); + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); + + GlobalVar new_global_var(func_name.value()); + new_global_var->checked_type_ = composite_func->checked_type(); + + if (comp_name == "cmsis-nn.quantized_softmax") { + EmitSoftMax(new_global_var, composite_func->body); + } + if (comp_name == "cmsis-nn.quantized_mul") { + EmitMul(new_global_var, composite_func->body); + } + if (comp_name == "cmsis-nn.quantized_add") { + EmitAdd(new_global_var, composite_func->body); + } + + Array args; + for (const auto& arg : call->args) { + args.push_back(VisitExpr(arg)); + } + + return Call(new_global_var, args, call->attrs, call->type_args, call->span); } } - } - - public: - String func_name_; - tir::PrimFunc primfunc_; -}; - -IRModule GenerateTIR(IRModule mod) { - String func_name; - Function func; - // Obtain external Relay Function that needs to be translated into TIR - ICHECK(mod->functions.size() == 1) << "Supports modules with single external Relay function."; - for (auto kv : mod->functions) { - func = Downcast(kv.second); - func_name = func->GetAttr(tvm::attr::kGlobalSymbol).value(); + return post; } - // Prepare PrimFunc from Relay Function - auto relay_to_tir = RelayToTIRVisitor(func_name); - relay_to_tir.VisitExpr(func->body); - - // Build the TIR IRModule from the generated PrimFunc - Map var_func_map; - var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc()); - return IRModule(var_func_map); -} + private: + IRModule ir_module_; + Target target_; +}; -transform::Pass RelayToTIR() { +tvm::transform::Pass RelayToTIR() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, transform::PassContext pc) { return GenerateTIR(m); }; + [=](IRModule ir_module, transform::PassContext pass_context) { + auto relay_to_tir = RelayToTIRVisitor(ir_module, Target("cmsis-nn")); + return relay_to_tir.Mutate(); + }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); } diff --git a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc b/src/relay/backend/contrib/cmsisnn/target.cc similarity index 57% rename from src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc rename to src/relay/backend/contrib/cmsisnn/target.cc index c8094109771b..99bc0bc7cb20 100644 --- a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -16,34 +17,22 @@ * specific language governing permissions and limitations * under the License. */ + #include -#include -#include +#include namespace tvm { + namespace relay { namespace contrib { namespace cmsisnn { -transform::Pass RelayToTIR(); - -runtime::Module CompileCMSISNN(const ObjectRef& ref) { - IRModule relay_mod; - Function relay_func = Downcast(ref); - auto func_name = relay_func->GetAttr(tvm::attr::kGlobalSymbol); - GlobalVar var = GlobalVar(func_name.value()); - relay_mod->Add(var, relay_func); - relay_mod = transform::InferType()(relay_mod); - - Array pass_seqs{transform::InferType(), RelayToTIR()}; - transform::Sequential seq(pass_seqs); - IRModule tir_mod = seq(relay_mod); - - const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate"); - return (*pf)(tir_mod); -} +tvm::transform::Pass RelayToTIR(); +runtime::Module TIRToRuntime(IRModule mod, Target target); -TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN); +TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU) + .set_attr("RelayToTIR", RelayToTIR()) + .set_attr("TIRToRuntime", TIRToRuntime); } // namespace cmsisnn } // namespace contrib diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index fb612e70311b..7350107d186c 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -25,21 +25,23 @@ #include "../../../../runtime/file_utils.h" #include "../../../../target/source/codegen_c.h" +#include "../../../../target/source/codegen_c_host.h" namespace tvm { -namespace codegen { - using namespace tir; +namespace relay { +namespace contrib { +namespace cmsisnn { -class CodeGenCMSISNN : public CodeGenC { +class CodeGenCMSISNN : public codegen::CodeGenCHost { public: - void Init(bool output_ssa) { + void Init(bool output_ssa, bool emit_asserts, std::string target_str) { decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; - CodeGenC::Init(output_ssa); + CodeGenCHost::Init(output_ssa, emit_asserts, target_str); } /*! @@ -47,92 +49,26 @@ class CodeGenCMSISNN : public CodeGenC { * * \return string of code that offloads a subgraph to the Cortex-M */ - void AddFunction(const PrimFunc& prim_func) { - PrintExternCPrefix(stream); - CodeGenC::AddFunction(prim_func); - PrintExternCPostfix(stream); - } - - private: - /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ - void PrintExternCPrefix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "extern \"C\" {\n"; - ss << "#endif\n"; - } - - /*! * \brief Creates a cplusplus guard postfix for extern "C" printing */ - void PrintExternCPostfix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "}\n"; - ss << "#endif\n"; - } + void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } }; -class CMSISNNModuleNode : public runtime::ModuleNode { - public: - CMSISNNModuleNode(const std::string& code, const std::string& fmt, - const Array& func_names) - : code_(code), fmt_(fmt), func_names_(func_names) {} - - std::string GetSource(const std::string& format) final { return code_; } - - const char* type_key() const { return "c"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); - } else { - return PackedFunc(nullptr); - } - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = runtime::GetFileFormat(file_name, format); - std::string meta_file = runtime::GetMetaFilePath(file_name); - if (fmt == "c") { - ICHECK_NE(code_.length(), 0); - runtime::SaveBinaryToFile(file_name, code_); - } else { - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - } - } - - protected: - std::string code_; - std::string fmt_; - Array func_names_; -}; - -static runtime::Module CMSISNNModuleNodeCreate(IRModule mod) { +runtime::Module TIRToRuntime(IRModule mod, Target target) { bool output_ssa = false; - CodeGenCMSISNN cg; + bool emit_asserts = false; + CodeGenCMSISNN codegen; Array function_names; - cg.Init(output_ssa); - ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; + codegen.Init(output_ssa, emit_asserts, target->str()); for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodegenCMSISNN: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); function_names.push_back(global_symbol.value()); - cg.AddFunction(f); + codegen.AddFunction(prim_func); } - std::string code = cg.Finish(); - auto n = make_object(code, "c", function_names); - return runtime::Module(n); + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); } -TVM_REGISTER_GLOBAL("runtime.CMSISNNModuleNodeCreate").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CMSISNNModuleNodeCreate(args[0]); -}); - -} // namespace codegen +} // namespace cmsisnn +} // namespace contrib +} // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 0d575b3ec498..964d7dee3ad1 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -44,6 +44,12 @@ struct Output { bool need_copy; }; +struct GenerateBodyOutput { + std::string decl; + std::vector buffers; + std::vector outputs; +}; + class CSourceModuleCodegenBase { public: CSourceModuleCodegenBase() = default; @@ -154,7 +160,8 @@ class CodegenCBase { * \endcode */ void GenerateBackendCFunc(const std::string& func_name, const Array& args, - const std::string& const_arr_name, const std::vector& outs) { + const std::string& const_arr_name, const std::vector& outs, + bool pass_dl_tensor = false) { // Print signature code_stream_ << "\n"; @@ -175,8 +182,12 @@ class CodegenCBase { PrintIndents(); code_stream_ << func_name << "_("; for (size_t i = 0; i < args.size(); i++) { - const auto& dtype_str = GetDtypeString(args[i]); - code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + if (pass_dl_tensor) { + code_stream_ << "arg" << i << ",\n"; + } else { + const auto& dtype_str = GetDtypeString(args[i]); + code_stream_ << "(" << dtype_str << "*)(arg" << i << "->data),\n"; + } PrintIndents(); } for (size_t i = 0; i < outs.size() - 1; i++) { diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc new file mode 100644 index 000000000000..f154f8641a64 --- /dev/null +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/cutlass/codegen.cc + * \brief Implementation of CUTLASS codegen. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../../utils.h" +#include "../codegen_c/codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +using namespace backend; +using Str2StrMap = std::unordered_map; + +static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, {"float32", "float"}}; + +constexpr const char* kAnyDim = "Any"; + +std::string GetDimAsStr(ObjectRef dim) { + if (auto d = dim.as()) { + return std::to_string(d->value); + } + return kAnyDim; +} + +inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { + for (int i = 0; i < indent; ++i) { + os << " "; + } + os << stmt; +} + +Str2StrMap GemmArgsCommon(const Map& attrs) { + Str2StrMap args; + auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); + auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); + auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); + args["ElementInputA"] = dtype_map.at(arg0_dtype); + args["ElementInputB"] = dtype_map.at(arg1_dtype); + args["ElementOutput"] = dtype_map.at(ret_dtype); + args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); + args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); + args["op_type"] = std::string(attrs["op_type"].as()->data); + args["lda"] = std::string(attrs["lda"].as()->data); + args["ldb"] = std::string(attrs["ldb"].as()->data); + args["ldc"] = std::string(attrs["ldc"].as()->data); + return args; +} + +Str2StrMap DenseArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(0)); + args["K"] = GetDimAsStr(arg0_shape->at(1)); + args["N"] = GetDimAsStr(arg1_shape->at(0)); + return args; +} + +Str2StrMap BatchMatmulArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + args["batch"] = GetDimAsStr(attrs["batch"]); + args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]); + args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]); + args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(1)); + args["K"] = GetDimAsStr(arg0_shape->at(2)); + args["N"] = GetDimAsStr(arg1_shape->at(1)); + return args; +} + +void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs, + const std::vector& func_args, const std::string& kernel, + bool has_bias, bool is_gelu, int m_axis_idx, int n_axis_idx, int k_axis_idx) { + CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); + CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); + CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n"); + CutlassPrint(gemm_decl, attrs.at("op_def")); + CutlassPrint(gemm_decl, "using Gemm = Operation_" + attrs.at("op_name") + ";\n"); + + auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) { + if (attrs.at(axis) == kAnyDim) { + return func_args[arg_idx] + "->shape[" + std::to_string(axis_idx) + "]"; + } else { + return attrs.at(axis); + } + }; + CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); + CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); + if (is_gelu) { + // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n"); + } else { + CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); + } + + ICHECK(func_args.size() >= 2); + CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); + CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); + if (has_bias) { + ICHECK(func_args.size() >= 3); + CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); + } + + CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n"); + + CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n"); + CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n"); + CutlassPrint(gemm_decl, " problem_size,\n"); +} + +void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + CutlassPrint(gemm_decl, + "size_t workspace_size = " + kernel + "::get_workspace_size(arguments);\n"); + // Allocate workspace memory + CutlassPrint(gemm_decl, + "cutlass::device_memory::allocation workspace(workspace_size);\n"); + // Instantiate CUTLASS kernel depending on template + CutlassPrint(gemm_decl, kernel + " gemm_op;\n"); + + // Check the problem size is supported or not + CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Initialize CUTLASS kernel with arguments and workspace pointer + CutlassPrint(gemm_decl, "status = gemm_op.initialize(arguments, workspace.get());\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + // Launch initialized CUTLASS kernel + CutlassPrint(gemm_decl, "status = gemm_op();\n"); + CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); +} + +std::string DenseOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + bool has_bias = false; + bool is_gelu = + attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 + if (attrs.at("op_type") == "cutlass.dense_bias" || + attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) { + has_bias = true; + } + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1); + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + if (has_bias) { + CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); + } else { + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + } + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + if (has_bias && !is_gelu) { + CutlassPrint(gemm_decl, " {alpha},\n"); + } else { + // For GeLU, we explicitly specify the scale. + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + } + CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices + + AppendGemmExecute(gemm_decl, "Gemm"); + return gemm_decl.str(); +} + +std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 1, 2); + + auto get_batch_stride = [&attrs, &func_args](const std::string& name, int arg0_idx, int arg1_idx, + int arg0_axis_idx, int arg1_axis_idx) { + if (attrs.at(name) == kAnyDim) { + return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) + "] * " + + func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) + "]"; + } else { + return attrs.at(name); + } + }; + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + + if (attrs.at("batch") == kAnyDim) { + CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n"); + } else { + CutlassPrint(gemm_decl, attrs.at("batch") + "};\n"); + } + + AppendGemmExecute(gemm_decl, "BatchedGemm"); + return gemm_decl.str(); +} + +class CodegenCutlass : public MemoizedExprTranslator>, public CodegenCBase { + public: + CodegenCutlass(const std::string& id, const Map& attrs) { + this->ext_func_id_ = id; + this->attrs_ = attrs; + } + + std::vector VisitExprDefault_(const Object* op) final { + LOG(FATAL) << "Cutlass codegen doesn't support: " << op->GetTypeKey(); + return {}; + } + + std::vector VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(GetRef(node)); + Output output; + output.name = node->name_hint(); + return {output}; + } + + std::vector VisitExpr_(const CallNode* call) final { + const auto* func = call->op.as(); + ICHECK(func) << "Only composite function is supported for CUTLASS."; + GenerateBodyOutput ret = GenerateCompositeFunctionCall(func, call); + ext_func_body_.push_back(ret.decl); + return ret.outputs; + } + + std::string JIT(const std::vector& out) { + code_stream_ << "void " << ext_func_id_ << "_("; + + for (const auto& arg : ext_func_args_) { + code_stream_ << "DLTensor* " << arg->name_hint() << ", "; + } + for (size_t i = 0; i < out.size() - 1; ++i) { + code_stream_ << out[i].dtype << "* out" << i << ", "; + } + code_stream_ << out.back().dtype << "* out" << out.size() - 1 << ") {\n"; + this->EnterScope(); + + // Function body + for (auto decl : buf_decl_) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : ext_func_body_) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, const_array_name_, out, true); + return code_stream_.str(); + } + + private: + std::vector GetArgumentNames(const CallNode* call) { + std::vector arg_names; + for (size_t i = 0; i < call->args.size(); ++i) { + auto res = VisitExpr(call->args[i]); + for (const auto& out : res) { + arg_names.push_back(out.name); + } + } + return arg_names; + } + + GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee, + const CallNode* caller) { + const auto pattern_name = callee->GetAttr(attr::kComposite); + ICHECK(pattern_name.defined()) << "Only functions with composite attribute are supported."; + + if (pattern_name == "cutlass.dense") { + const auto* dense_call = GetRootCall(callee->body.as(), 0, {"nn.dense"}); + return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->op.as()->name; + const auto* dense_call = + GetRootCall(callee->body.as(), 1, {"nn.dense", add_or_bias_add}); + return GenerateBody(dense_call, "cutlass_dense_bias", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_relu") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[0].as()->op.as()->name; + const auto* dense_call = + GetRootCall(callee->body.as(), 2, {"nn.dense", add_or_bias_add, "nn.relu"}); + return GenerateBody(dense_call, "cutlass_dense_bias_relu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_gelu_fp16") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; + const auto* dense_call = GetRootCall(callee->body.as(), 8, + {"nn.dense", add_or_bias_add, "multiply", "cast", "erf", + "cast", "multiply", "add", "multiply"}); + return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.dense_bias_gelu_fp32") { + const CallNode* current_call = callee->body.as(); + std::string add_or_bias_add = current_call->args[1].as()->op.as()->name; + const auto* dense_call = GetRootCall( + callee->body.as(), 6, + {"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add", "multiply"}); + return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), + DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.batch_matmul") { + const auto* batch_matmul_call = + GetRootCall(callee->body.as(), 0, {"nn.batch_matmul"}); + return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller), + BatchMatmulArgs(std::ref(attrs_))); + } + LOG(FATAL) << "Unknown composite function: " << pattern_name; + return {}; + } + + GenerateBodyOutput GenerateBody(const CallNode* root_call, const std::string& func_name, + const std::vector& func_args, + const Str2StrMap& attribute_args) { + // Make function call with input buffers when visiting arguements + ICHECK_GT(func_args.size(), 0); + std::ostringstream decl_stream; + decl_stream << "(" << func_args[0]; + for (size_t i = 1; i < func_args.size(); ++i) { + decl_stream << ", " << func_args[i]; + } + // Analyze the output buffers + std::vector out_types; + if (root_call->checked_type()->IsInstance()) { + auto type_node = root_call->checked_type().as(); + for (auto field : type_node->fields) { + ICHECK(field->IsInstance()); + out_types.push_back(field); + } + } else if (root_call->checked_type()->IsInstance()) { + ICHECK(root_call->checked_type()->IsInstance()); + out_types.push_back(root_call->checked_type()); + } else { + LOG(FATAL) << "Unrecognized type node: " << AsText(root_call->checked_type(), false); + } + GenerateBodyOutput ret; + for (const auto& out_type : out_types) { + const std::string out = "out" + std::to_string(buf_idx_++); + decl_stream << ", " << out; + Output output; + output.name = out; + output.dtype = GetDtypeString(out_type.as()); + output.need_copy = false; + ret.outputs.push_back(output); + } + decl_stream << ");"; + if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" || + func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") { + ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); + } else if (func_name == "cutlass_batch_matmul") { + ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args); + } + return ret; + } + /*! \brief The id of the external cutlass ext_func. */ + std::string ext_func_id_{""}; + /*! \brief The attrs of the external cutlass ext_func. */ + Map attrs_; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls CUTLASS kernels. */ + Array ext_func_args_; + /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ + std::vector ext_func_body_; + /*! \brief The array declared to store the constant values. */ + std::string const_array_name_; + /*! \brief The declaration of intermediate buffers. */ + std::vector buf_decl_; +}; // class CodegenCutlass + +class CutlassModuleCodegen : public CSourceModuleCodegenBase { + public: + std::pair> GenCutlassFunc(const Function& func) { + ICHECK(func.defined()) << "Input error: expect a Relay function."; + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); + const auto* attrs = func->attrs.as(); + ICHECK(attrs != nullptr); + const auto dict = attrs->dict; + CodegenCutlass builder(sid, dict); + auto out = builder.VisitExpr(func->body); + code_stream_ << builder.JIT(out); + return {sid, {}}; + } + + runtime::Module CreateCSourceModule(const ObjectRef& ref) override { + // create header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + // cutlass header + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + + ICHECK(ref->IsInstance()); + auto res = GenCutlassFunc(Downcast(ref)); + std::string code = code_stream_.str(); + String sym = std::get<0>(res); + Array variables = std::get<1>(res); + // Create a CSource module + const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); + ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; + return (*pf)(code, "cu", Array{sym}, variables); + } + + private: + /*! \brief The code stream that will be compiled by NVCC */ + std::ostringstream code_stream_; +}; // CutlassModuleCodegen + +/*! + * \brief The external cutlass compiler/codegen tool. It takes a Relay + * expression/module and compile it into a runtime module. + */ +runtime::Module CutlassCompiler(const ObjectRef& ref) { + CutlassModuleCodegen cutlass; + return cutlass.CreateCSourceModule(ref); +} + +TVM_REGISTER_GLOBAL("relay.ext.cutlass").set_body_typed(CutlassCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index ae58c2f08e8c..fa1dbc66d8a7 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -231,12 +231,6 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C } private: - struct GenerateBodyOutput { - std::string decl; - std::vector buffers; - std::vector outputs; - }; - std::vector GetArgumentNames(const CallNode* call) { std::vector arg_names; for (size_t i = 0; i < call->args.size(); ++i) { diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 3e675215e7e0..88dee9216a48 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -195,6 +195,20 @@ sl::TensorsAndId MakeOps(const sl::TensorAndId& op) { return ops; } +String MakeVariant(auto configuration) { + String variant = configuration.value()->variant; + // Transform variant string to lowercase for comparison + std::string variant_string = variant.c_str(); + std::transform(variant_string.begin(), variant_string.end(), variant_string.begin(), ::tolower); + std::string variant_n78 = "ethos-n78"; + if (variant_string == variant_n78) { + String tops = configuration.value()->tops; + String ple_ratio = configuration.value()->ple_ratio; + variant = "Ethos-N78_" + tops + "TOPS_" + ple_ratio + "PLE_RATIO"; + } + return variant; +} + NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { // Initialise everything auto ctx = transform::PassContext::Current(); @@ -203,8 +217,9 @@ NetworkWithIDs ConstructNetworkVisitor::Construct(const Function& func) { cfg = AttrsWithDefaultValues(); } NetworkWithIDs network_with_ids; - network_ = sl::CreateNetwork(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + network_ = sl::CreateNetwork( + sl::GetFwAndHwCapabilities(sl::EthosNVariantFromString(MakeVariant(cfg).c_str()), + static_cast(std::stoul(cfg.value()->sram_size)))); network_with_ids.network = network_; operand_table_.clear(); @@ -614,8 +629,9 @@ EthosnError EthosnCompiler::SupportedSetup() { auto cfg = ctx->GetConfig("relay.ext.ethos-n.options").defined() ? ctx->GetConfig("relay.ext.ethos-n.options") : AttrsWithDefaultValues(); - m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + m_Queries = std::make_unique( + sl::GetFwAndHwCapabilities(sl::EthosNVariantFromString(cfg.value()->variant.c_str()), + std::stoul(cfg.value()->sram_size))); if (m_Queries == nullptr) { return EthosnError("Could not initialise Ethos-N compiler isSupported"); } diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index ca2df05e958d..279569596f1b 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -227,7 +227,9 @@ NetworkWithIDs ConstructNetwork(const IRModule& mod, const GlobalVar& var, const /*! \brief Attributes to store the compiler options for Ethos-N */ struct EthosnCompilerConfigNode : public tvm::AttrsNode { String variant; - int sram_size_bytes; + String sram_size; + String tops; + String ple_ratio; bool strategy0; bool strategy1; bool strategy3; @@ -247,9 +249,15 @@ struct EthosnCompilerConfigNode : public tvm::AttrsNode, std::vector> GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network); + /*! + * \brief Query interface used to determine if the Ethos-N hardware supports an operation + * with the supplied parameters. + */ static std::unique_ptr m_Queries; }; diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index cae20210ec4f..c41399e314ef 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -17,14 +17,18 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include +#include #include #include #include #include #include +#include "../../../op/call/call.h" + namespace tvm { namespace relay { namespace contrib { @@ -109,7 +113,13 @@ class ConvertAddToSubtract : public MixedModeMutator { GlobalVar new_global_var(func_name.value()); new_global_var->checked_type_ = func->checked_type(); ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); + + // Since we are replacing the Relay function with a call to a TIR function, we must use the + // call_lowered op. + auto call_lowered_attrs = make_object(); + call_lowered_attrs->metadata.Set("relay_attrs", call->attrs); + return CallLowered(std::move(new_global_var), call->args, + std::move(Attrs(call_lowered_attrs)), call->type_args, call->span); } } diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc new file mode 100644 index 000000000000..3f5c2f4cb00f --- /dev/null +++ b/src/relay/backend/executor.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/executor.cc + * \brief Executor Registry + */ + +#include + +#include "../../node/attr_registry.h" +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(ExecutorNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const Executor& executor = Downcast(obj); + p->stream << executor->name; + }); + +/********** Registry-related code **********/ + +using ExecutorRegistry = AttrRegistry; + +Executor Executor::Create(String name, Map attrs) { + const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Executor \"" + name + "\" is not defined"); + } + + for (const auto& kv : attrs) { + if (!reg->key2vtype_.count(kv.first)) { + throw Error("Attribute \"" + kv.first + "\" is not available on this Executor"); + } + std::string expected_type = reg->key2vtype_.at(kv.first).type_key; + std::string actual_type = kv.second->GetTypeKey(); + if (expected_type != actual_type) { + throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + + "\" but instead found \"" + actual_type + "\""); + } + } + + for (const auto& kv : reg->key2default_) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + + return Executor(name, DictAttrs(attrs)); +} + +Array Executor::ListExecutors() { return ExecutorRegistry::Global()->ListAllNames(); } + +Map Executor::ListExecutorOptions(const String& name) { + Map options; + const ExecutorRegEntry* reg = ExecutorRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Executor \"" + name + "\" is not defined"); + } + for (const auto& kv : reg->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + +ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { + return ExecutorRegistry::Global()->RegisterOrGet(name); +} + +/********** Register Executors and options **********/ + +TVM_REGISTER_EXECUTOR("aot") + .add_attr_option("unpacked-api") + .add_attr_option("interface-api"); + +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); + +TVM_REGISTER_EXECUTOR("vm"); + +/********** Registry **********/ + +TVM_REGISTER_GLOBAL("relay.backend.CreateExecutor").set_body_typed(Executor::Create); +TVM_REGISTER_GLOBAL("relay.backend.GetExecutorAttrs").set_body_typed([](const Executor& executor) { + return executor->attrs->dict; +}); + +TVM_REGISTER_GLOBAL("relay.backend.ListExecutors").set_body_typed(Executor::ListExecutors); +TVM_REGISTER_GLOBAL("relay.backend.ListExecutorOptions") + .set_body_typed(Executor::ListExecutorOptions); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index debd669126c4..ac3c835ed648 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler.h" #include "./utils.h" @@ -403,64 +405,75 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, const std::string& func_name, - GraphAttrs attrs) { + std::vector GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) { + Call call = GetRef(call_node); std::vector inputs; - for (auto arg : op->args) { - auto res = VisitExpr(arg); - for (auto nr : res) { - inputs.push_back(nr); - } - } + std::string func_name; - /// An adapted version of the storage optimization for the time being. - bool reshape_only = false; - if (op->attrs.defined()) { - if (auto tir_call_attrs = op->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - if (metadata.count(attr::kReshapeOnly) && - Downcast(metadata[attr::kReshapeOnly])->value == 1) { - reshape_only = true; - } + if (call->op == CallLoweredOp()) { + // Extract function and arguments from the call_lowered op + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - auto relay_attrs = Downcast(tir_call_attrs->metadata["relay_attrs"]); + func_name = call_lowered_props.lowered_func->name_hint; - for (auto p : relay_attrs->dict) { - if (p.second.as()) { - attrs[p.first] = std::string(Downcast(p.second)); + for (const Expr& arg : call_lowered_props.arguments) { + for (auto n : VisitExpr(arg)) { + inputs.push_back(n); + } + } + if (call_lowered_props.attrs.metadata.count("relay_attrs")) { + if (auto relay_attrs = + call_lowered_props.attrs.metadata["relay_attrs"].as()) { + for (auto p : relay_attrs->dict) { + if (p.second.as()) { + attrs[p.first] = std::string(Downcast(p.second)); + } } } } - } - - if (reshape_only && ShareSameStorage(GetRef(op), op->args[0])) { - auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); - return AddNode(node, GetRef(op)); + bool reshape_only = false; + if (call_lowered_props.attrs.metadata.count(attr::kReshapeOnly) && + Downcast(call_lowered_props.attrs.metadata[attr::kReshapeOnly])->value == + 1) { + reshape_only = true; + } + if (reshape_only && + ShareSameStorage(GetRef(call_node), call_lowered_props.arguments[0])) { + auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs); + return AddNode(node, call); + } + } else if (!call_node->attrs.defined()) { // Call is an extern function + std::cout << "call_node: \n" << PrettyPrint(call) << std::endl; + const auto* func = call_node->op.as(); + ICHECK(func) << "Expected the operator to be a global var, but got " + << call_node->op->GetTypeKey(); // getting a relay fn here, not sure why. + func_name = func->name_hint; + + for (const Expr& arg : call_node->args) { + for (auto n : VisitExpr(arg)) { + inputs.push_back(n); + } + } + } else { + LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n" + << "The graph executor code generator expects all calls to be call_lowered, " + << "but found: " << std::endl + << PrettyPrint(call); } // Compute the operator name, because we used the get unique name when generating the kernel. auto op_name = _GetUniqueName(func_name); auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs); - return AddNode(node, GetRef(op)); + return AddNode(node, call); } std::vector VisitExpr_(const CallNode* call_node) override { - relay::Call call = GetRef(call_node); auto props = GetOnDeviceProps(call_node); if (props.body.defined()) { // See through "on_device" calls. return VisitExpr(props.body); } - - const auto* global_node = call->op.as(); - ICHECK(global_node) - << "Non-primitive-call nodes should have been transformed away.\n" - << "The graph executor code generator expects all calls to have their callee " - "normalized to a GlobalVar, but found:" - << std::endl - << PrettyPrint(call); - auto prim_fn_name = global_node->name_hint; - return GraphAddCallNode(call_node, prim_fn_name, GraphAttrs()); + return GraphAddCallNode(call_node, GraphAttrs()); } std::vector VisitExpr_(const LetNode* op) override { @@ -619,7 +632,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; + TargetMap tmp = args[1]; tec::TargetMap targets; for (const auto& it : tmp) { auto dev_type = it.first.as(); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 961252a14fa7..4031dfdcd6e7 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -24,6 +24,7 @@ */ #include #include +#include #include #include #include @@ -32,6 +33,7 @@ #include "../../support/arena.h" #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/memory.h" #include "../transforms/device_aware_visitors.h" #include "./utils.h" @@ -139,6 +141,8 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { protected: /*! \brief internal token map */ std::unordered_map> token_map_; + /*! \brief empty token map */ + const std::vector no_tokens_; /*! * \brief Get the necessary token. @@ -146,6 +150,11 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ const std::vector& GetToken(const Expr& expr) { + this->VisitExpr(expr); + // Functions don't require data storage, represented by the empty token + if (expr->checked_type().as()) { + return no_tokens_; + } // See through on_device calls. Expr real_expr = IgnoreOnDevice(expr); this->VisitExpr(real_expr); @@ -159,8 +168,9 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding * the result of evaluating \p op. */ - void CreateToken(const ExprNode* op, bool can_realloc) { - return CreateTokenOnDevice(op, GetInScopeDeviceType(GetRef(op)), can_realloc); + void CreateToken(const ExprNode* expr_node, bool can_realloc) { + return CreateTokenOnDevice(expr_node, GetInScopeDeviceType(GetRef(expr_node)), + can_realloc); } /*! @@ -203,12 +213,12 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; - void DeviceAwareVisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { // create token for the call node. - CreateToken(op, true); + CreateToken(call_node, true); // for each input, visit argument token. - for (Expr arg : op->args) { + for (Expr arg : call_node->args) { for (StorageToken* tok : GetToken(arg)) { tok->ref_counter += 1; } @@ -273,7 +283,6 @@ class StorageAllocator : public StorageAllocaBaseVisitor { << "expressions are assigned with virtual device types. Either all " "or none of the expressions are expected to be annotated."; } - return backend::StaticMemoryPlan(smap); } @@ -320,10 +329,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor { using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; // The call map - void DeviceAwareVisitExpr_(const CallNode* op) final { + void DeviceAwareVisitExpr_(const CallNode* call_node) final { std::vector args; // for each input, visit argument token. - for (Expr arg : op->args) { + + for (const Expr& arg : call_node->args) { + // Note: GetToken skips GlobalVars and handles tuples properly, so we don't need to treat + // call_lowered specially. for (StorageToken* tok : GetToken(arg)) { args.push_back(tok); } @@ -337,20 +349,17 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // // TODO(tvm-team) Update checks of flat memory enablement when we support // opaque-nd memory planning to skip this path. - if (IsReshape(op)) { - // TODO(@electriclilies, jroesch): This check is failing because the size of args is 3 - // I can't figure out where the extra args are coming from, I assume it must be related - // to the relay_attrs field we added to the TIRCallArgs, but I don't know where / how - // that's happening... + + if (IsReshape(call_node)) { ICHECK_EQ(args.size(), 1U); - ReuseInputToken(op, args[0]); + ReuseInputToken(call_node, args[0]); } else { // create token for the call node. - CreateToken(op, true); + CreateToken(call_node, true); } // check if there is orphaned output that can be released immediately. - for (StorageToken* tok : token_map_.at(op)) { + for (StorageToken* tok : token_map_.at(call_node)) { CheckForRelease(tok); } for (StorageToken* tok : args) { @@ -376,12 +385,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { return fn->HasNonzeroAttr(attr::kReshapeOnly); } - if (call->attrs.defined()) { - if (auto tir_call_attrs = call->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - return metadata.count(attr::kReshapeOnly) && - (Downcast(metadata[attr::kReshapeOnly])->value == 1); - } + if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call); + Map metadata = call_lowered_props.attrs.metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); } return false; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ef89fd9c9c6c..4835d7618a2e 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -34,10 +35,12 @@ #include #include #include +#include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/pass_utils.h" -#include "./te_compiler.h" +#include "te_compiler.h" namespace tvm { namespace relay { @@ -292,8 +295,11 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(IRModule unified_mod, Device device, Target target) - : unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {} + Interpreter(IRModule unified_mod, CompilationConfig config, Device device) + : unified_mod_(unified_mod), + config_(std::move(config)), + device_(device), + debug_op_(Op::Get("debug")) {} template T WithFrame(const Frame& fr, const std::function& f) { @@ -386,12 +392,12 @@ class Interpreter : public ExprFunctor, per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module); auto mod_itr = per_target_module_std_map.find(target); ICHECK(mod_itr != per_target_module_std_map.end()) - << "No target module for target '" << target->str() << "'"; + << "No target module for target " << target->ToDebugString(); const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { ICHECK(target_module->ContainGlobalVar(var->name_hint)) - << "No global var for '" << var->name_hint << "' in module for target '" << target->str() - << "'"; + << "No global var for '" << var->name_hint << "' in module for target " + << target->ToDebugString(); lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); } @@ -407,8 +413,9 @@ class Interpreter : public ExprFunctor, // Extract all the packed functions. for (const auto& var : all_tir_fn_vars) { PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); - ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint - << "' in compiled module for target '" << target->str() << "'"; + ICHECK(packed_func != nullptr) + << "No packed function for global var '" << var->name_hint + << "' in compiled module for target " << target->ToDebugString(); compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } @@ -677,80 +684,94 @@ class Interpreter : public ExprFunctor, } ObjectRef VisitExpr_(const CallNode* call_node) final { - std::vector args; - for (auto arg : call_node->args) { - args.push_back(Eval(arg)); - } + if (call_node->op == CallLoweredOp()) { // Special case: Call a lowered TIR function. + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); - if (call_node->op == OnDeviceOp()) { - // Special case: The call 'on_device(expr)' denotes that expr should be executed on - // a particular device. We can ignore this during interpretation. - ICHECK_EQ(call_node->args.size(), 1UL); - return args[0]; - } + // Evaluate only function args + std::vector args; + for (auto arg : call_lowered_props.arguments) { + args.push_back(Eval(arg)); + } - // We should not find calls to operators after running fusion and lowering. - if (const OpNode* op_node = call_node->op.as()) { - LOG(FATAL) << "found " << op_node->name - << "; operators should have been removed by previous passes; try " - "fusing and lowering"; - } + // TODO(mbs): Make calling convention first-class in Relay. + Array all_prim_fn_vars; + if (call_lowered_props.attrs.metadata.count("all_prim_fn_vars")) { + all_prim_fn_vars = + Downcast>(call_lowered_props.attrs.metadata.at("all_prim_fn_vars")); + } + GlobalVar prim_shape_fn_var; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_var")) { + prim_shape_fn_var = + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_var")); + } + Array all_prim_shape_fn_vars; + if (call_lowered_props.attrs.metadata.count("all_prim_shape_fn_vars")) { + all_prim_shape_fn_vars = Downcast>( + call_lowered_props.attrs.metadata.at("all_prim_shape_fn_vars")); + } + Array prim_shape_fn_states; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_states")) { + prim_shape_fn_states = + Downcast>(call_lowered_props.attrs.metadata.at("prim_shape_fn_states")); + } - if (const ConstructorNode* con = call_node->op.as()) { - // Special case: ADT constructor - return ConstructorValue(con->tag, args, GetRef(con)); - } + size_t num_shape_inputs = 0; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_inputs")) { + num_shape_inputs = static_cast( + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_inputs")) + ->value); + } + size_t num_shape_outputs = 0; + if (call_lowered_props.attrs.metadata.count("prim_shape_fn_num_outputs")) { + num_shape_outputs = static_cast( + Downcast(call_lowered_props.attrs.metadata.at("prim_shape_fn_num_outputs")) + ->value); + } + ICHECK(config_->optional_homogeneous_target.defined()); + return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars, + config_->optional_homogeneous_target, prim_shape_fn_var, + all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, config_->host_se_scope->target, args); + } else { // All other calls + // Evaluate all arguments + std::vector args; + for (auto arg : call_node->args) { + args.push_back(Eval(arg)); + } - if (const GlobalVarNode* gvn = call_node->op.as()) { - if (const TIRCallAttrs* attrs = call_node->attrs.as()) { - // Special case: Call a lowered TIR function. - // TODO(mbs): Make calling convention first-class in Relay. - Array all_prim_fn_vars; - if (attrs->metadata.count("all_prim_fn_vars")) { - all_prim_fn_vars = Downcast>(attrs->metadata.at("all_prim_fn_vars")); - } - GlobalVar prim_shape_fn_var; - if (attrs->metadata.count("prim_shape_fn_var")) { - prim_shape_fn_var = Downcast(attrs->metadata.at("prim_shape_fn_var")); - } - Array all_prim_shape_fn_vars; - if (attrs->metadata.count("all_prim_shape_fn_vars")) { - all_prim_shape_fn_vars = - Downcast>(attrs->metadata.at("all_prim_shape_fn_vars")); - } - Array prim_shape_fn_states; - if (attrs->metadata.count("prim_shape_fn_states")) { - prim_shape_fn_states = - Downcast>(attrs->metadata.at("prim_shape_fn_states")); - } - size_t num_shape_inputs = 0; - if (attrs->metadata.count("prim_shape_fn_num_inputs")) { - num_shape_inputs = static_cast( - Downcast(attrs->metadata.at("prim_shape_fn_num_inputs"))->value); - } - size_t num_shape_outputs = 0; - if (attrs->metadata.count("prim_shape_fn_num_outputs")) { - num_shape_outputs = static_cast( - Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); - } + if (call_node->op == OnDeviceOp()) { + // Special case: The call 'on_device(expr)' denotes that expr should be executed on + // a particular device. We can ignore this during interpretation. + ICHECK_EQ(call_node->args.size(), 1UL); + return args[0]; + } + if (const ConstructorNode* con = call_node->op.as()) { + // Special case: ADT constructor - return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, target_, - prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, - num_shape_inputs, num_shape_outputs, cpu_target_, args); + return ConstructorValue(con->tag, args, GetRef(con)); } - } - // Now we just evaluate and expect to find a closure. - ObjectRef fn_val = Eval(call_node->op); - if (const InterpreterClosureObj* closure_node = fn_val.as()) { - auto closure = GetRef(closure_node); - return Invoke(closure, args); - } else if (const RecClosureObj* closure_node = fn_val.as()) { - return Invoke(closure_node->clos, args, closure_node->bind); - } else { - LOG(FATAL) << "internal error: type error, expected function value in the call " - << "position"; - return ObjectRef(); + if (const OpNode* op_node = call_node->op.as()) { + // Except for call_lowered and on_device, we should not find calls to operators after + // running fusion and lowering. + LOG(FATAL) << "found " << op_node->name + << "; operators should have been removed by previous passes; try " + "fusing and lowering"; + } + + // Now we just evaluate and expect to find a closure. + // TODO(@electriclilies): How should call_lowered behave with closures? + ObjectRef fn_val = Eval(call_node->op); + if (const InterpreterClosureObj* closure_node = fn_val.as()) { + auto closure = GetRef(closure_node); + return Invoke(closure, args); + } else if (const RecClosureObj* closure_node = fn_val.as()) { + return Invoke(closure_node->clos, args, closure_node->bind); + } else { + LOG(FATAL) << "internal error: type error, expected function value in the call " + << "position"; + return ObjectRef(); + } } } @@ -884,13 +905,11 @@ class Interpreter : public ExprFunctor, // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; + /*! \brief Compilation config describing the available targets. */ + CompilationConfig config_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; - // Unique target describing how to compile for primitives (but not shape functions). - Target target_; - // Default 'CPU' target for shape primitives. - Target cpu_target_{"llvm"}; // Call stack. Stack stack_; // The distinguished 'debug' operator, which is handled specially. @@ -898,25 +917,21 @@ class Interpreter : public ExprFunctor, }; /*! - * Lowers all calls to primitives in \p mod appropriate for device and target. Returns the + * Lowers all calls to primitives in \p mod appropriate for \p config. Returns the * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, Device device, Target target) { - // Things to initialize to pass into tec::LowerTEPass - // We only have one device-specific target. - tec::TargetMap targets = {{device.device_type, target}}; - if (device.device_type != kDLCPU) { - // However some primitives (eg dynamic shape functions) must always execute on the CPU, - // so make sure we have a target for that. - targets.emplace(kDLCPU, Target("llvm")); +IRModule Prepare(IRModule mod, CompilationConfig config) { + tec::TargetMap tec_target_map; + for (const auto& pair : config->legacy_target_map) { + tec_target_map.emplace(static_cast(pair.first->value), pair.second); } - // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), // Figure out which devices should be used to execute. - transform::PlanDevices(device.device_type), + // TODO(mbs): Should ignore all existing annotations when constant folding + transform::PlanDevices(config->default_primitive_se_scope->device_type()), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -926,7 +941,8 @@ IRModule Prepare(IRModule mod, Device device, Target target) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(targets, /*module_name=*/"intrp", [](Function func) { /* no-op */ })}); + tec::LowerTEPass(tec_target_map, /*module_name=*/"intrp", + [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); @@ -979,7 +995,15 @@ class NeedsPreparationVisitor : public ExprVisitor { TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, Target target) { VLOG_CONTEXT << "EvalFunction"; - VLOG(1) << "evaling module:\n" << PrettyPrint(mod) << "and expression:\n" << PrettyPrint(expr); + VLOG(1) << "evaling module:" << std::endl + << PrettyPrint(mod) << "and expression:" << std::endl + << PrettyPrint(expr); + + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); // // Step 1: Prepare mod. @@ -1024,9 +1048,9 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - IRModule lowered_mod = Prepare(mod_with_expr, device, target); + IRModule lowered_mod = Prepare(mod_with_expr, config); - std::shared_ptr intrp = std::make_shared(lowered_mod, device, target); + std::shared_ptr intrp = std::make_shared(lowered_mod, config, device); // // Step 2: Evaluate target function to a closure. @@ -1065,12 +1089,18 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); + std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - IRModule mod = Prepare(mod_and_global.first, device, target); + IRModule mod = Prepare(mod_and_global.first, config); - Interpreter intrp(mod, device, target); + Interpreter intrp(mod, config, device); Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint); if (expr.as() == nullptr) { // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc new file mode 100644 index 000000000000..1c08cbd29d1e --- /dev/null +++ b/src/relay/backend/runtime.cc @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/runtime.cc + * \brief Runtime Registry + */ + +#include + +#include "../../node/attr_registry.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(RuntimeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { + const Runtime& runtime = Downcast(obj); + p->stream << runtime->name; + }); + +/********** Registry-related code **********/ + +using RuntimeRegistry = AttrRegistry; + +Runtime Runtime::Create(String name, Map attrs) { + const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Runtime \"" + name + "\" is not defined"); + } + + for (const auto& kv : attrs) { + if (!reg->key2vtype_.count(kv.first)) { + throw Error("Attribute \"" + kv.first + "\" is not available on this Runtime"); + } + std::string expected_type = reg->key2vtype_.at(kv.first).type_key; + std::string actual_type = kv.second->GetTypeKey(); + if (expected_type != actual_type) { + throw Error("Attribute \"" + kv.first + "\" should have type \"" + expected_type + + "\" but instead found \"" + actual_type + "\""); + } + } + + for (const auto& kv : reg->key2default_) { + if (!attrs.count(kv.first)) { + attrs.Set(kv.first, kv.second); + } + } + + return Runtime(name, DictAttrs(attrs)); +} + +Array Runtime::ListRuntimes() { return RuntimeRegistry::Global()->ListAllNames(); } + +Map Runtime::ListRuntimeOptions(const String& name) { + Map options; + const RuntimeRegEntry* reg = RuntimeRegistry::Global()->Get(name); + if (reg == nullptr) { + throw Error("Runtime \"" + name + "\" is not defined"); + } + for (const auto& kv : reg->key2vtype_) { + options.Set(kv.first, kv.second.type_key); + } + return options; +} + +RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { + return RuntimeRegistry::Global()->RegisterOrGet(name); +} + +/********** Register Runtimes and options **********/ + +TVM_REGISTER_RUNTIME("c").add_attr_option("system-lib"); + +TVM_REGISTER_RUNTIME("cpp"); + +/********** Registry **********/ + +TVM_REGISTER_GLOBAL("relay.backend.CreateRuntime").set_body_typed(Runtime::Create); +TVM_REGISTER_GLOBAL("relay.backend.GetRuntimeAttrs").set_body_typed([](const Runtime& runtime) { + return runtime->attrs->dict; +}); + +TVM_REGISTER_GLOBAL("relay.backend.ListRuntimes").set_body_typed(Runtime::ListRuntimes); +TVM_REGISTER_GLOBAL("relay.backend.ListRuntimeOptions").set_body_typed(Runtime::ListRuntimeOptions); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 445602540dbb..915fc22b2052 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,7 @@ #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../transforms/device_aware_visitors.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -222,7 +224,8 @@ class TECompilerImpl : public TECompilerNode { auto global_var = GlobalVar(func_name); global_var->checked_type_ = key->source_func->checked_type(); ir_module->Add(global_var, key->source_func); - value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module); + value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule{nullptr}, + tir::PrimFunc{nullptr}, {}, ir_module); return value; } @@ -243,16 +246,19 @@ class TECompilerImpl : public TECompilerNode { return value; } } - - // NOTE: array will copy on write. - Array all_args = Array(cfunc->inputs); - for (te::Tensor arg : cfunc->outputs) { - all_args.push_back(arg); + if (cfunc->prim_func.defined()) { + cfunc->funcs->Update(cfunc->prim_fn_var, cfunc->prim_func.value()); + } else { + // NOTE: array will copy on write. + Array all_args = Array(cfunc->inputs); + for (te::Tensor arg : cfunc->outputs) { + all_args.push_back(arg); + } + // lower the function + std::unordered_map binds; + auto func_name = cfunc->prim_fn_var->name_hint; + cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); } - - std::unordered_map binds; - auto func_name = cfunc->prim_fn_var->name_hint; - cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds)); value->cached_func = cfunc; return value; } @@ -313,6 +319,46 @@ TECompiler::TECompiler() { data_ = object; } +/*! \brief The global TE compiler */ +TECompiler& TECompiler::Global() { + static TECompiler* inst = new TECompiler(make_object()); + return *inst; +} +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_meta_schedule", Bool); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerGlobal").set_body_typed([]() { + return TECompiler::Global(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); + +TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerClear").set_body_typed([](TECompiler self) { + self->Clear(); +}); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerLower") + .set_body_typed([](TECompiler self, CCacheKey key, const String mod_name) { + return self->Lower(key, mod_name); + }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerJIT") + .set_body_typed([](TECompiler self, CCacheKey key) { return self->JIT(key); }); + +TVM_REGISTER_GLOBAL("relay.backend._TECompilerListItems").set_body_typed([](TECompiler self) { + TECompilerImpl* ptr = dynamic_cast(self.operator->()); + ICHECK(ptr != nullptr); + return ptr->ListItems(); +}); + using AnalysisRemapping = std::unordered_map; std::tuple IsDeviceCopy(const Function& func) { @@ -416,7 +462,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { * to the TIR implementation, and attributes to attach to the call to identify it as * a TIR call. */ - std::pair LowerFunction(Function func, Target target) { + Expr MakeLoweredCall(Function func, Array visited_args, Array type_args, Span span, + Target target) { if (func->GetAttr(attr::kCompiler).defined()) { // BYOC flow. CCacheKey key = CCacheKey(func, target); @@ -424,6 +471,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { ICHECK(ext_func.defined()) << "Lowering returned undefined function for " << ext_func->prim_fn_var->name_hint; + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT Map prim_fns; relay::Function func_with_metadata = func; func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var); @@ -434,87 +482,91 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // act when we process a function. this->process_fn_(func_with_metadata); - // TODO(mbs): Need TIRCallAttrs or equiv so targets know this is an extern. // TODO(mbs): Dynamic shapes? - return {ext_func->prim_fn_var, Attrs()}; - } + // TODO(@mbs, electriclilies): Make extern functions explicit + return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span); - // Non-External Relay Function - VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); - CCacheKey key = CCacheKey(func, target); - CachedFunc lowered_func = compiler_->Lower(key, module_name_); - VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; - - // Collect all the lowered functions produced for this primitive function. - Map prim_fns; - Array all_prim_fn_vars; - for (auto prim_fn : lowered_func->funcs->functions) { - CHECK(prim_fn.second.as()) << "must be a prim fn"; - prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); - all_prim_fn_vars.push_back(prim_fn.first); - VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; - } + } else { + // Non-External Relay Function + VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" + << PrettyPrint(func); + CCacheKey key = CCacheKey(func, target); + CachedFunc lowered_func = compiler_->Lower(key, module_name_); + VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; - // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT - relay::Function func_with_metadata = func; - func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); - func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); - func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); + // Collect all the lowered functions produced for this primitive function. + Map prim_fns; + Array all_prim_fn_vars; + for (auto prim_fn : lowered_func->funcs->functions) { + CHECK(prim_fn.second.as()) << "must be a prim fn"; + prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); + all_prim_fn_vars.push_back(prim_fn.first); + VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; + } - // Provide a callback hook which allows one-level up code generators to - // act when we process a function. - this->process_fn_(func_with_metadata); + // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT + relay::Function func_with_metadata = func; + func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var); + func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); + func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target); - auto tir_call_attrs = make_object(); - if (func->HasNonzeroAttr(attr::kReshapeOnly)) { - tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); - } + // Provide a callback hook which allows one-level up code generators to + // act when we process a function. + this->process_fn_(func_with_metadata); - auto device_copy = IsDeviceCopy(func); - if (std::get<0>(device_copy)) { - // Record that device copy source and destination devices so the device planner can - // still follow along. - auto source_device = std::get<1>(device_copy); - auto dst_device = std::get<2>(device_copy); - tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device)); - tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); - } + auto call_lowered_attrs = make_object(); + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + } - tir_call_attrs->metadata.Set("relay_attrs", func->attrs); - tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); - - if (IsDynamic(func->ret_type)) { - // Also lower the dynamic shape function. - // Shape function keys use the underlying primitive function as their 'function', - // but the generic 'cpu' target as the target since all shape functions run - // on the host cpu irrespective of where the primitive runs. - // TODO(mbs): Cleanup target handling. - Target shape_target("llvm"); - VLOG(1) << "lowering to target '" << shape_target->str() - << "' for dynamic shape function for primitive"; - CCacheKey shape_key(func, shape_target); - CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); - // Capture the shape function's global var and parameters 'states' in call - // annotations so calling convention can be recovered. - // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. - // The way the shape function calling convention is derived and passed to call sites - // via the 'parameter states' could be improved. - tir_call_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - tir_call_attrs->metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - tir_call_attrs->metadata.Set("prim_shape_fn_num_inputs", - Integer(static_cast(lowered_shape_func->inputs.size()))); - tir_call_attrs->metadata.Set("prim_shape_fn_num_outputs", - Integer(static_cast(lowered_shape_func->outputs.size()))); - Array all_prim_shape_fn_vars; - for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { - CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; - all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + auto device_copy = IsDeviceCopy(func); + if (std::get<0>(device_copy)) { + // Record that device copy source and destination devices so the device planner can + // still follow along. + auto source_device = std::get<1>(device_copy); + auto dst_device = std::get<2>(device_copy); + call_lowered_attrs->metadata.Set("source_device", tvm::Integer(source_device)); + call_lowered_attrs->metadata.Set("dst_device", tvm::Integer(dst_device)); } - tir_call_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); - } - return {lowered_func->prim_fn_var, Attrs(tir_call_attrs)}; + call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + + if (IsDynamic(func->ret_type)) { + // Also lower the dynamic shape function. + // Shape function keys use the underlying primitive function as their 'function', + // but the generic 'cpu' target as the target since all shape functions run + // on the host cpu irrespective of where the primitive runs. + // TODO(mbs): Cleanup target handling. + Target shape_target("llvm"); + VLOG(1) << "lowering to target '" << shape_target->str() + << "' for dynamic shape function for primitive"; + CCacheKey shape_key(func, shape_target); + CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); + // Capture the shape function's global var and parameters 'states' in call + // annotations so calling convention can be recovered. + // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available. + // The way the shape function calling convention is derived and passed to call sites + // via the 'parameter states' could be improved. + call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs->metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs->metadata.Set( + "prim_shape_fn_num_outputs", + Integer(static_cast(lowered_shape_func->outputs.size()))); + Array all_prim_shape_fn_vars; + for (auto prim_shape_fn : lowered_shape_func->funcs->functions) { + CHECK(prim_shape_fn.second.as()) << "must be a prim fn"; + all_prim_shape_fn_vars.push_back(prim_shape_fn.first); + } + call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + } + return CallLowered(lowered_func->prim_fn_var, visited_args, Attrs(call_lowered_attrs), + type_args, span); + } } std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { @@ -523,7 +575,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { BaseFunc prim_func = ResolveToPrimitive(new_value); if (prim_func.defined() && !prim_func->IsInstance()) { - // Remember let var is bound to (possibly indirectly) to a non-tir primitive. + // Remember let var is bound to (possibly indirectly) a non-tir primitive. Function func = Downcast(prim_func); primitive_functions_.emplace(var, func); } @@ -539,8 +591,21 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } + Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Nothing to lower inside primitive functions. + return GetRef(function_node); + } else { + return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node); + } + } + Expr DeviceAwareVisitExpr_(const CallNode* call_node) override { + // Passes before lowering might insert a call_lowered to call a function that has already + // been lowered. Therefore we might see call_lowered ops here, but we don't need to do anything + // because ResolveToPrimitive returns null for all calls where the call_node->op is an OpNode Call call = GetRef(call_node); + // Look for (indirect) calls to primitives. BaseFunc prim_func = ResolveToPrimitive(call_node->op); if (!prim_func.defined()) { @@ -551,10 +616,16 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { return ExprMutator::VisitExpr_(call_node); } + // Similarly transform arguments. + Array args; + for (const auto& arg : call_node->args) { + args.push_back(VisitExpr(arg)); + } + // Already lowered by other means so we don't need to mutate - // the call + // the call but we do need to mutate the arguments if (prim_func->IsInstance()) { - return std::move(call); + return Call(call_node->op, args, call_node->attrs); } // Find the desired target device. @@ -568,20 +639,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // TODO(mbs): Replace device_type with target so this lookup is unnecessary. target = GetTargetFromInteger(device_type, targets_); } - - // Lower the primitive function for that target. - Function func = Downcast(prim_func); - std::pair pair = LowerFunction(func, target); - - // Similarly transform arguments. - Array args; + Array visited_args; for (const auto& arg : call_node->args) { - args.push_back(VisitExpr(arg)); + visited_args.push_back(VisitExpr(arg)); } - - // Replace with direct call to lowered primitive, and attach annotations to record calling - // convention. - return Call(pair.first, args, pair.second); + // Lower the primitive function for that target. + Function func = Downcast(prim_func); + return MakeLoweredCall(func, visited_args, call_node->type_args, call_node->span, target); } IRModule module_; @@ -857,8 +921,6 @@ void UpdateFunctionMetadata(Function relay_func, IRModule LowerTE(const IRModule& module, TargetMap targets, const String& module_name, std::function process_fn) { - DLOG(INFO) << "lowering module:\n" << PrettyPrint(module); - TECompiler compiler; auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module); diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 248fd40f98eb..d0401e9605f7 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -127,6 +127,7 @@ class TECompiler : public ObjectRef { explicit TECompiler(ObjectPtr n) : ObjectRef(n) {} TECompilerNode* operator->() { return static_cast(get_mutable()); } using ContainerType = TECompilerNode; + TVM_DLL static TECompiler& Global(); }; /*! @@ -147,7 +148,7 @@ void UpdateFunctionMetadata(Function relay_func, * \param dev_type * \return Target */ -Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets); /*! * \brief Update the "main" control function's metadata @@ -193,7 +194,7 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \returns The pass which lowers primative functions to TIR + * \returns The pass which lowers primitive functions to TIR */ transform::Pass LowerTEPass(TargetMap targets, const String& module_name, std::function process_fn); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index ec87cfc98931..266bd719545a 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -70,7 +70,8 @@ CCacheKey::CCacheKey(Function source_func, Target target) { CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array inputs, tvm::Array outputs, te::Schedule schedule, - tvm::Array shape_func_param_states, IRModule funcs) { + tir::PrimFunc prim_func, tvm::Array shape_func_param_states, + IRModule funcs) { auto n = make_object(); n->target = target; n->prim_fn_var = prim_fn_var; @@ -117,11 +118,12 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + use_meta_schedule_ = backend::IsMetaScheduleEnabled(); } - CachedFunc Create(const Function& prim_func, std::function renamer) { + CachedFunc Create(const Function& relay_func, std::function renamer) { Array fn_inputs; - for (Var param : prim_func->params) { + for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); @@ -131,9 +133,11 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator memo_[param] = inputs; } readable_name_stream_ << "fused"; - auto outputs = this->VisitExpr(prim_func->body); + auto outputs = this->VisitExpr(relay_func->body); auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); @@ -149,7 +153,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator prim_fn_name = renamer(prim_fn_name); } auto prim_fn_var = GlobalVar(prim_fn_name); - prim_fn_var->checked_type_ = prim_func->checked_type(); + prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships // between inputs and outputs, and those should not be scheduled. @@ -161,7 +165,8 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } - te::Schedule schedule; + te::Schedule schedule{nullptr}; + tir::PrimFunc prim_func{nullptr}; // No need to register schedule for device copy op. if (anchor_attrs_.as() == nullptr && create_schedule_) { if (use_auto_scheduler_) { @@ -174,20 +179,39 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator schedule = Downcast(obj); } } + if (use_meta_schedule_) { + const auto* f_create_func = runtime::Registry::Get("te.CreatePrimFuncFromOutputs"); + const auto* f_meta_schedule = + runtime::Registry::Get("meta_schedule.MetaScheduleContextQueryInsideWithScope"); + ICHECK(f_create_func) << "te.CreatePrimFuncFromOutputs is not registered"; + ICHECK(f_meta_schedule) + << "meta_schedule.MetaScheduleContextQueryInsideWithScope is not registered"; + prim_func = (*f_create_func)(tensor_outs); + Optional opt_mod_or_base_func = + (*f_meta_schedule)(prim_fn_name, IRModule({{GlobalVar(prim_fn_name), relay_func}}), + Array{IRModule({{GlobalVar(prim_fn_name), prim_func}})}); + if (const auto* result = opt_mod_or_base_func.as()) { + prim_func = GetRef(result); + } else { + prim_func = tir::PrimFunc(nullptr); + } + } // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined()) { + if (!schedule.defined() && !prim_func.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); } - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); + if (schedule.defined()) { + for (const auto& scalar : scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } } } } - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {}); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}); } Array VisitExpr_(const VarNode* op) final { @@ -334,6 +358,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator std::ostringstream readable_name_stream_; Array scalars_; bool use_auto_scheduler_; + bool use_meta_schedule_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. const Op& device_copy_op_; @@ -394,6 +419,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> // Generate a name. auto candidate_name = readable_name_stream_.str(); constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; truncated_name << candidate_name.substr(0, kMaxFuncNameLength); @@ -446,8 +473,8 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> std::unordered_map binds; IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds); - return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states, - ir_module); + return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr}, + shape_func_param_states, ir_module); } Array VisitExpr(const Expr& expr) final { @@ -462,8 +489,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const VarNode* var_node) final { auto var = GetRef(var_node); - auto it = param_states_.find(var); - if (it == param_states_.end()) { + auto it = param_arg_map_.find(var); + if (it != param_arg_map_.end()) { + // This var is a parameter of a nested function. Visit the corresponding argument in the + // function call site. + return VisitExpr(it->second); + } + if (param_states_.find(var) == param_states_.end()) { LOG(FATAL) << "Unexpected free variable " << var->name_hint(); return {}; } else { @@ -538,6 +570,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } Array VisitExpr_(const CallNode* call_node) final { + if (auto* func = call_node->op.as()) { + for (size_t i = 0; i < func->params.size(); ++i) { + param_arg_map_[func->params[i]] = call_node->args[i]; + } + return VisitExpr(func->body); + } static auto fshape_func = Op::GetAttrMap("FShapeFunc"); static auto tshape_data_dependent = Op::GetAttrMap("TShapeDataDependent"); ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; @@ -597,7 +635,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } Array VisitExpr_(const FunctionNode* op) final { - LOG(FATAL) << "Do not support sub function"; + LOG(FATAL) << "Nested functions are not allowed to be visited."; return Array(); } @@ -640,6 +678,10 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> std::vector data_dependents_per_input_; /*! \brief Scalars used in the shape function */ Array scalars_; + /*! \brief Map from parameters of a nested function to corresponding arguments in a function + * call site. + */ + std::unordered_map param_arg_map_; }; CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 47ba96b2c77e..2171880fd6a5 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -62,7 +62,6 @@ struct LoweredOutputNode : public Object { v->Visit("outputs", &outputs); v->Visit("implementation", &implementation); } - static constexpr const char* _type_key = "relay.LoweredOutput"; TVM_DECLARE_FINAL_OBJECT_INFO(LoweredOutputNode, Object); }; @@ -130,16 +129,18 @@ class CCacheKey : public ObjectRef { /*! \brief Node container to represent a cached function. */ struct CachedFuncNode : public Object { - /* \brief compiled target */ + /*! \brief compiled target */ tvm::Target target; /*! \brief Primitive Function Name */ GlobalVar prim_fn_var; - /* \brief The inputs to the function */ + /*! \brief The inputs to the function */ tvm::Array inputs; - /* \brief The outputs to the function */ + /*! \brief The outputs to the function */ tvm::Array outputs; /*! \brief The schedule to the function */ te::Schedule schedule; + /*! \brief The TIR function if lowering in the meta schedule path */ + Optional prim_func; /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; /*! \brief The lowered functions to support the function. */ @@ -151,6 +152,7 @@ struct CachedFuncNode : public Object { v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); v->Visit("schedule", &schedule); + v->Visit("prim_func", &prim_func); v->Visit("funcs", &funcs); v->Visit("shape_func_param_states", &shape_func_param_states); } @@ -162,7 +164,7 @@ struct CachedFuncNode : public Object { class CachedFunc : public ObjectRef { public: CachedFunc(tvm::Target target, GlobalVar prim_fn_name, tvm::Array inputs, - tvm::Array outputs, te::Schedule schedule, + tvm::Array outputs, te::Schedule schedule, tir::PrimFunc prim_func, tvm::Array shape_func_param_states, IRModule funcs = IRModule(Map({}))); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 6d59b858927c..16cbe0e8dbca 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -327,7 +327,7 @@ inline relay::Function BindParamsByName( for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { - repeat_var.insert(arg); + repeat_var.insert(name_dict[name]); } else { name_dict[name] = arg; } @@ -428,11 +428,11 @@ inline bool IsAutoSchedulerEnabled() { } /*! - * \brief Return whether the compile engine cache is disabled in the pass context. + * \brief Return whether the meta schedule is enabled in the pass context. */ -inline bool IsCompileEngineCacheDisabled() { +inline bool IsMetaScheduleEnabled() { return transform::PassContext::Current() - ->GetConfig("relay.backend.disable_compile_engine_cache", Bool(false)) + ->GetConfig("relay.backend.use_meta_schedule", Bool(false)) .value(); } @@ -446,7 +446,7 @@ inline bool IsCompileEngineCacheDisabled() { * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPassPrefix(const Map& targets, bool is_vm); +Array GetPassPrefix(const TargetMap& targets, bool is_vm); /*! \brief Target hash function */ struct TargetStrHash { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3c1cd81274f..02477d05673b 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,7 +80,6 @@ namespace vm { using namespace tvm::runtime; using namespace tvm::runtime::vm; using namespace relay::transform; -using namespace tec; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); @@ -251,7 +250,7 @@ int GetFallbackDevice() { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetMap targets, Target target_host) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), @@ -304,7 +303,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } VisitExpr(func); } - return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type); + std::vector params_device_type_index; + params_device_type_index.reserve(params_device_type.size()); + for (auto device_type : params_device_type) { + params_device_type_index.push_back(static_cast(device_type)); + } + return VMFunction(var->name_hint, params_, instructions_, registers_num_, + params_device_type_index); } /*! \brief Attrs objects for each op. */ @@ -317,7 +322,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t NewRegister() { return registers_num_++; } inline void Emit(const Instruction& instr) { - VLOG(1) << "VMCompiler::Emit: instr=" << instr; + VLOG(2) << "VMCompiler::Emit: instr=" << instr; ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { case Opcode::AllocADT: @@ -458,7 +463,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void EmitShapeFunc(Function func, Array inputs, Array outputs) { // Lower shape function - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = context_->compiler->LowerShapeFunc(key); int op_index = -1; // pick the only function inside the context @@ -534,7 +539,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } } - CCacheKey key(func, target); + tec::CCacheKey key(func, target); auto mangle_fn = [](String name) { return name; }; auto cfunc = context_->compiler->Lower(key, mangle_fn); // <<<< one-func-at-a-time lowering @@ -613,8 +618,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } }) .Match("memory.alloc_storage", - [this, call_node](const Array& args, const Attrs& attrs, - const Array& type_arg) { + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { ICHECK_EQ(args.size(), 2); // Compute the size of the allocation. this->VisitExpr(args[0]); @@ -703,7 +707,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto global = GetRef(global_node); auto it = context_->global_map.find(global); ICHECK(it != context_->global_map.end()); - VLOG(1) << "VisitExpr_: generating invoke for " << global->name_hint + VLOG(2) << "VisitExpr_: generating invoke for " << global->name_hint << " with func_index=" << it->second; // TODO(tvm-team): @@ -904,7 +908,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { +void VMCompiler::Lower(IRModule mod, const tvm::TargetMap& targets, + const tvm::Target& target_host) { exec_ = make_object(); targets_ = targets; target_host_ = target_host; @@ -941,12 +946,6 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe } } -#if USE_RELAY_DEBUG - for (auto vm_func : exec_->functions) { - VLOG(1) << vm_func << "-------------"; - } -#endif // USE_RELAY_DEBUG - // populate constants for (auto data : context_.constants) { exec_->constants.push_back(data); @@ -967,10 +966,16 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } +#if USE_RELAY_DEBUG + for (const auto& vm_func : exec_->functions) { + VLOG(1) << vm_func << "-------------"; + } +#endif // USE_RELAY_DEBUG + backend::UpdateAutoSchedulerOpWeights(context_.compiler); } -transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { +transform::Sequential MemoryOpt(tvm::Target host_target, tvm::TargetMap targets) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; @@ -1016,9 +1021,10 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { return transform::Sequential(pass_seqs); } -IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, +IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets_arg, const Target& target_host_arg) { - TargetsMap targets = targets_arg; + VLOG_CONTEXT << "VMCompiler::OptimizeModule"; + TargetMap targets = targets_arg; Target target_host = target_host_arg; CheckAndUpdateHostConsistency(&targets, &target_host); if (params_.size()) { diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index af3c5bccbeff..5b51d7821d78 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,7 +62,6 @@ using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; -using TargetsMap = Map; struct VMCompilerContext { // The module context for the compilation @@ -111,7 +110,7 @@ class VMCompiler : public runtime::ModuleNode { * to target mapping. For homogeneous compilation, it is a singleton build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); + void Lower(IRModule mod, const TargetMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); @@ -127,7 +126,7 @@ class VMCompiler : public runtime::ModuleNode { * * \return The optimized IRModule. */ - IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host); + IRModule OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host); /*! * \brief Populate the global function names in a map where the value is used @@ -137,7 +136,7 @@ class VMCompiler : public runtime::ModuleNode { protected: /*! \brief Target devices. */ - TargetsMap targets_; + TargetMap targets_; /*! \brief Target host device. */ tvm::Target target_host_; /*! \brief Global shared meta data */ diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 6924f2598f6f..674424872251 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -87,7 +87,7 @@ struct PrimitiveInliner : ExprMutator { // in w(...) while ((var_node = op.as())) { auto var = GetRef(var_node); - DLOG(INFO) << "Var: " << var << std::endl; + VLOG(1) << "Var: " << var << std::endl; auto it = var_map.find(GetRef(var_node)); if (it != var_map.end()) { op = it->second; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3b3c8797d7f2..c7a81f9f0f03 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -115,8 +115,6 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s n->attrs = std::move(attrs); n->type_args = std::move(type_args); n->span = std::move(span); - n->saved_deleter_ = n->deleter_; - n->deleter_ = CallNode::Deleter_; data_ = std::move(n); } @@ -282,6 +280,9 @@ inline void Dismantle(const Expr& expr) { if (auto* op = const_cast(node.as())) { op->args = Array(); } + if (auto* op = const_cast(node.as())) { + op->body = Expr(); + } } // eject stack.pop(); @@ -308,6 +309,11 @@ inline void Dismantle(const Expr& expr) { if (op->tuple.use_count() < 2) { fpush_to_stack(op->tuple); } + } else if (const LetNode* op = node.as()) { + // do not process let if used elsewhere + if (op->body.use_count() < 2) { + fpush_to_stack(op->body); + } } } } @@ -336,5 +342,28 @@ void CallNode::Deleter_(Object* ptr) { auto c = GetRef(p); } +/* + * Non-recursive destructor + */ +Let::~Let() { + // attempt to dismantle if referenced one or zero times + if (this->use_count() < 2) { + if (this->as() && this->as()->body.defined()) { + Dismantle(*this); + } + } +} + +/* + * LetNode's deleter + */ +void LetNode::Deleter_(Object* ptr) { + auto p = reinterpret_cast(ptr); + // resore original deleter + p->deleter_ = p->saved_deleter_; + // create Let reference in order to invoke ~Let + auto c = GetRef(p); +} + } // namespace relay } // namespace tvm diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index 8b00839cda33..27b61333c9eb 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -76,6 +76,16 @@ Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { // by the function's attributes. return expr; } + OnDeviceProps props = GetOnDeviceProps(expr); + if (props.body.defined()) { + // Don't nest on_devices. + // If the inner and outer device types differ then we need to be careful: + // - If the inner on_device is_fixed then it disagrees with the outer. + // - If the outer on_device is_fixed then it implies a hidden device_copy + // Otherwise just use the inner device type and ignore the outer. + ICHECK(props.device_type == device_type || (!is_fixed && !props.is_fixed)); + return OnDevice(props.body, device_type, is_fixed || props.is_fixed); + } return OnDevice(expr, device_type, is_fixed); } diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc new file mode 100644 index 000000000000..9485b72d8374 --- /dev/null +++ b/src/relay/op/call/call.cc @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/call/call.cc + * \brief Operators for calling lowered functions. + */ + +#include "./call.h" + +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(CallLoweredAttrs); + +// call_lowered +bool CallLoweredRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Types = [func, call_args, ret_type] + if (types.size() != 3u) { + return false; + } + const auto* func_type = types[0].as(); + if (!func_type) { + return false; + } + + const auto* tuple_type_node = types[1].as(); + if (!tuple_type_node) { + return false; + } + + // Constraint to ensure function arguments are the same type as the inputs to the function (modulo + // the Tuple wrapper) + reporter->Assign(GetRef(tuple_type_node), TupleType(func_type->arg_types, {})); + // Constraint to ensure the output of call_lowered is the same as the function's return type + reporter->Assign(types[2], func_type->ret_type); + return true; +} + +const Op& CallLoweredOp() { return Op::Get("call_lowered"); } + +Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span) { + // Right now, call_lowered only supports func being a global var pointing to the lowered + // function. + ICHECK(func.as()) + << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey(); + ICHECK(attrs.as()) + << "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey(); + return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs), + std::move(type_args), std::move(span)); +} + +TVM_REGISTER_GLOBAL("relay.op.call_lowered") + .set_body_typed([](Expr func, Array inputs, Attrs attrs, Array type_args, + Span span) { + const TupleNode* tuple_node = inputs.as(); + return CallLowered(func, tuple_node->fields, attrs, type_args, span); + }); + +RELAY_REGISTER_OP("call_lowered") + .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("func", "Function", "The lowered function to call.") + .add_argument("call_args", "Tuple", "The input tensors.") + .add_type_rel("CallLoweredRel", CallLoweredRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { + ICHECK(call_node->op == CallLoweredOp()) + << "GetCallLoweredProps expects the op to be call_lowered. "; + ICHECK(call_node->args.size() == 2) << "Expected call_lowered to have 2 arguments. "; + const auto* function = call_node->args[0].as(); + ICHECK(function) << "Expected first arg to call_lowered to be a GlobalVar. "; + + const auto* tuple_args = call_node->args[1].as(); + ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple. "; + + ICHECK(call_node->attrs.defined()) << "Attributes for call_lowered should be defined!"; + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " + << call_node->attrs->GetTypeKey(); + return CallLoweredProps{std::move(GetRef(function)), std::move(tuple_args->fields), + std::move(*attrs)}; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h new file mode 100644 index 000000000000..381be6724e0d --- /dev/null +++ b/src/relay/op/call/call.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/call/call.h + * \brief Operators for calling lowered functions. + */ +#ifndef TVM_RELAY_OP_CALL_CALL_H_ +#define TVM_RELAY_OP_CALL_CALL_H_ + +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Helper to construct a Relay call with the call_lowered op. + * \param func Lowered function to call with call_lowered. + * \param inputs Arguments to be passed to the function. + * \param attrs Function attributes, should be TIRCallAttrs. + * \param type_args Type arguments for the call. + * \param span TVM span for propogating debugging info. + * \return + */ +Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span); + +/*! + * \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to + * Registry::Get. + */ +const Op& CallLoweredOp(); + +/*! + * \brief Lowered function and the arguments to call it with. + */ +struct CallLoweredProps { + /*! \brief Global variable pointing to the lowered function. */ + GlobalVar lowered_func; + /*! \brief Array of the arguments to call lowered_func with. */ + Array arguments; + /*! \brief Arguments from the call_lowered op. */ + CallLoweredAttrs attrs; +}; + +/*! + * \brief Helper to extract the lowered function and its arguments from Call("call_lowered", ...). + * Will fail if called on a Call whose op is not "call_lowered" \param call_node CallNode that we + * want to get the function and its arguments from. + */ +CallLoweredProps GetCallLoweredProps(const CallNode* call_node); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_CALL_CALL_H_ diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc new file mode 100644 index 000000000000..5b4900edc74b --- /dev/null +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/binary_elementwise.cc + * \brief Binary elementwise operators definitions for the Arm(R) Ethos(TM)-U NPU. + */ +#include + +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU binary elementwise operators */ +struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode { + String operator_type; + double ifm_scale; + int ifm_zero_point; + double ifm2_scale; + int ifm2_zero_point; + double ofm_scale; + int ofm_zero_point; + IndexExpr ifm_channels; + IndexExpr ifm2_channels; + bool reversed_operands; + String activation; + int clip_min; + int clip_max; + String ifm_layout; + String ifm2_layout; + String ofm_layout; + String ofm_dtype; + + TVM_DECLARE_ATTRS(EthosuBinaryElementwiseAttrs, "relay.attrs.EthosuBinaryElementwiseAttrs") { + TVM_ATTR_FIELD(operator_type) + .describe( + "The type of the binary elementwise operator." + "'ADD'" + "'SUB'" + "'MUL'" + "'MIN'" + "'MAX'" + "'SHR'" + "'SHL'"); + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm2_scale) + .describe("The quantization scale for the Input Feature Map tensor 2."); + TVM_ATTR_FIELD(ifm2_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor 2."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ifm_channels).describe("The number of the Input Feature Map channels."); + TVM_ATTR_FIELD(ifm2_channels).describe("The number of the Input Feature Map 2 channels."); + TVM_ATTR_FIELD(reversed_operands) + .describe("True if IFM2 is the first operand and IFM is the second operand.") + .set_default(false); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function." + "Available activations for activation type:" + "{int8, uint8}: 'NONE', 'CLIP', 'TANH', 'SIGMOID', 'LUT'" + "{int32}: 'NONE'") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(ifm_layout) + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ifm2_layout) + .describe("The layout of the Input Feature Map tensor 2. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_layout) + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_dtype).describe( + "The Output Feature Map tensor type." + "MUL, ADD, SUB {IFM}->{OFM}:" + " {uint8, int8 int32} -> {uint8, int8, int32}, any pairing" + "MAX, MIN:" + " IFM and OFM must be of the same type, one of:" + " {int8, uint8}" + "SHR {IFM}->{OFM}:" + " {int32}->{int8, uint8, int32}, any pairing" + "SHL:" + " {int32}->{int32} only"); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs); + +bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const int ifm_index = 0; + const int ifm2_index = 1; + const int result_index = 3; + ICHECK_EQ(types.size(), result_index + 1); + + const auto* ifm = types[ifm_index].as(); + const auto* ifm2 = types[ifm2_index].as(); + if (ifm == nullptr) return false; + if (ifm2 == nullptr) return false; + + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr."; + + String operator_type = param->operator_type; + auto ifm_dtype = ifm->dtype; + auto ifm2_dtype = ifm2->dtype; + DataType ofm_dtype; + + if (param->ofm_dtype == "int8") { + ofm_dtype = DataType::Int(8); + } else if (param->ofm_dtype == "uint8") { + ofm_dtype = DataType::UInt(8); + } else if (param->ofm_dtype == "int32") { + ofm_dtype = DataType::Int(32); + } + + if (ifm_dtype != ifm2_dtype) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << "type for ifm2 be the same of ifm but was " << ifm2_dtype + << " instead of " << ifm_dtype); + return false; + } + + if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && + ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ifm but was " << ifm_dtype); + return false; + } + if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && + ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); + return false; + } + } else if (operator_type == "MIN" || operator_type == "MAX") { + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) for ifm but was " << ifm_dtype); + return false; + } + if (ifm_dtype != ofm_dtype) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type + << " type for ofm be the same of ifm but was " << ofm_dtype + << " instead of " << ifm_dtype); + return false; + } + } else if (operator_type == "SHR") { + if (ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ifm but was " + << ifm_dtype); + return false; + } + if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && + ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); + return false; + } + } else if (operator_type == "SHL") { + if (ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ifm but was " + << ifm_dtype); + + return false; + } + if (ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ofm but was " + << ofm_dtype); + return false; + } + } else { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or " + << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type); + return false; + } + + // Assign ofm type + auto ofm_shape = EthosuInferBinaryElementwiseOutputShape(ifm->shape, param->ifm_layout, + param->ofm_layout, param->ifm_channels); + reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype)); + return true; +} + +Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_type, + double ifm_scale, int ifm_zero_point, double ifm2_scale, + int ifm2_zero_point, double ofm_scale, int ofm_zero_point, + IndexExpr ifm_channels, IndexExpr ifm2_channels, + bool reversed_operands, String activation, int clip_min, + int clip_max, String ifm_layout, String ifm2_layout, + String ofm_layout, String ofm_dtype) { + auto attrs = make_object(); + + attrs->operator_type = std::move(operator_type); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->ifm2_scale = ifm2_scale; + attrs->ifm2_zero_point = ifm2_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->ifm_channels = std::move(ifm_channels); + attrs->ifm2_channels = std::move(ifm2_channels); + attrs->reversed_operands = reversed_operands; + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->ifm_layout = std::move(ifm_layout); + attrs->ifm2_layout = std::move(ifm2_layout); + attrs->ofm_layout = std::move(ofm_layout); + attrs->ofm_dtype = std::move(ofm_dtype); + + static const Op& op = Op::Get("contrib.ethosu.binary_elementwise"); + return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise") + .set_body_typed(MakeEthosuBinaryElementwise); + +RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise") + .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise operator. + +This Relay operator corresponds to the hardware-implemented quantized +binary elementwise operation found on Ethos(TM)-U NPU. It accepts either NHWC +or NHCWB16 format for the inputs data (input feature maps, or IFMs). + +Reference: https://developer.arm.com/documentation/102420/0200/ + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ofm**: (1, ofm_height, ofm_width, ifm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("ifm2", "Tensor", "The Input Feature Map tensor 2 (IFM2).") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuBinaryElementwise", EthosuBinaryElementwiseRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index bdda81bc7708..bdaa9da52618 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -32,6 +32,24 @@ namespace op { namespace contrib { namespace ethosu { +Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, + String ifm_layout, String ofm_layout, + IndexExpr ofm_channels) { + // In the case of NHCWB16, convert the ifm shape to NHW (C not required for this function) + if (ifm_layout == "NHCWB16") { + ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]}; + } + Array oshape({ifm_shape[0], ifm_shape[1], ifm_shape[2], ofm_channels}); + + // If the ofm is NHCWB16, convert the layout + if (ofm_layout == "NHCWB16") { + int channel_bricks = 1 + (oshape[3].as()->value - 1) / 16; + oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16}; + } + + return oshape; +} + Array EthosuInferKernelOutput(Array ifm_shape, String ifm_layout, String ofm_layout, Array kernel_shape, IndexExpr ofm_channels, Array dilation, diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index b5377e6e8bdf..574fb91181ef 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -33,6 +33,17 @@ namespace op { namespace contrib { namespace ethosu { +/*! \brief Infer the output tensor shape for binary elementwise operators. + * \param ifm_shape The shape of Input Feature Map. + * \param ifm_layout The layout of the IFM (NHWC or NHCWB16). + * \param ofm_layout The layout of the OFM (NHWC or NHCWB16). + * \param ofm_channels The number of Output Feature Map channels. + * \return The shape of the output tensor. + */ +Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, + String ifm_layout, String ofm_layout, + IndexExpr ofm_channels); + /*! \brief Infer the output tensor shape for convolution and pooling operators. * \param ifm_shape The shape of Input Feature Map. * \param ifm_layout The layout of the IFM (NHWC or NHCWB16). diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index bad10bf66f3a..9471f88ac376 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -70,7 +70,7 @@ struct EthosuConv2DAttrs : public tvm::AttrsNode { .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") .set_default(NullValue>()); TVM_ATTR_FIELD(ofm_channels) - .describe("The number of OFM channels.") + .describe("The number of the Output Feature Map channels.") .set_default(NullValue()); TVM_ATTR_FIELD(strides) .set_default(Array({1, 1})) @@ -123,12 +123,28 @@ bool EthosuConv2DRel(const Array& types, int num_inputs, const Attrs& attr if (ifm == nullptr || weight == nullptr) return false; const auto* param = attrs.as(); CHECK(param != nullptr) << "EthosuConv2DAttrs cannot be nullptr."; - CHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) - << "Expected ethosu_conv2d type(uint8) or type(int8) for ifm but was " << ifm->dtype; - CHECK(weight->dtype == DataType::UInt(8) || weight->dtype == DataType::Int(8)) - << "Expected ethosu_conv2d type(uint8) or type(int8) for weight but was " << weight->dtype; - CHECK(scale_bias->dtype == DataType::UInt(8)) - << "Expected ethosu_conv2d type(uint8) for scale_bias but was " << scale_bias->dtype; + + if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_conv2d input data type " + << "of type(uint8) or type(int8) but was " << ifm->dtype); + return false; + } + + if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_conv2d weight data type " + << "of type(uint8) or type(int8) but was " << weight->dtype); + return false; + } + + if (scale_bias->dtype != DataType::UInt(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_conv2d scale bias data type " + << "of type(uint8) but was " << scale_bias->dtype); + return false; + } // The scale_bias should be provided as a tensor of size {ofm_channels, 10} reporter->Assign(types[2], TensorType({weight->shape[0], 10}, DataType::UInt(8))); @@ -179,7 +195,7 @@ RELAY_REGISTER_OP("contrib.ethosu.conv2d") .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized convolution operator. This Relay operator corresponds to the hardware-implemented quantized -convolution operation found on Ethos(TM)-U NPUs. It accepts either NHWC +convolution operation found on Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format for the input data (Input Feature Map, or IFM) and OHWI format for the kernel weights. @@ -201,7 +217,7 @@ of type uint8. For more detail, refer to the Technical Reference Manual linked a .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") .add_argument("weight", "Tensor", "The weight tensor.") .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") - .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'.") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'.") .set_support_level(11) .add_type_rel("EthosuConv2D", EthosuConv2DRel); diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index fa73645d45de..7918285ce1b7 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -123,15 +123,30 @@ bool EthosuDepthwiseConv2DRel(const Array& types, int num_inputs, const At const auto* param = attrs.as(); ICHECK(param != nullptr) << "EthosuDepthwiseConv2DAttrs cannot be nullptr."; - ICHECK(ifm->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) - << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for ifm but was " - << ifm->dtype; - ICHECK(weight->dtype == DataType::UInt(8) || ifm->dtype == DataType::Int(8)) - << "Expected ethosu_depthwise_conv2d type(uint8) or type(int8) for weight but was " - << weight->dtype; - ICHECK(scale_bias->dtype == DataType::UInt(8)) - << "Expected ethosu_depthwise_conv2d type(uint8) for scale_bias but was " - << scale_bias->dtype; + + if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_depthwise_conv2d input data type " + << "of type(uint8) or type(int8) but was " << ifm->dtype); + return false; + } + + if (weight->dtype != DataType::UInt(8) && weight->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_depthwise_conv2d weight data type " + << "of type(uint8) or type(int8) but was " << weight->dtype); + return false; + } + + if (scale_bias->dtype != DataType::UInt(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_depthwise_conv2d scale bias data type " + << "of type(uint8) but was " << scale_bias->dtype); + return false; + } // Collect the ifm, weight and ofm tensors for using in the inference function Array tensor_types = {types[0], types[1], types[4]}; @@ -186,7 +201,7 @@ RELAY_REGISTER_OP("contrib.ethosu.depthwise_conv2d") .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized depthwise operator. This Relay operator corresponds to the hardware-implemented quantized -depthwise operation found on Ethos(TM)-U NPUs. It accepts either NHWC or NHCWB16 format +depthwise operation found on Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format for the input data (input feature map, or IFM) and OHWI format for the kernel weights. - **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) @@ -201,7 +216,7 @@ for the input data (input feature map, or IFM) and OHWI format for the kernel we .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") .add_argument("weight", "Tensor", "The weight tensor.") .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") - .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") .set_support_level(11) .add_type_rel("EthosuDepthwiseConv2D", EthosuDepthwiseConv2DRel); diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc new file mode 100644 index 000000000000..bcf54fbd4a2d --- /dev/null +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/pooling.cc + * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU. + */ +#include + +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU pooling operator */ +struct EthosuPoolingAttrs : public tvm::AttrsNode { + String pooling_type; + double ifm_scale; + int ifm_zero_point; + double ofm_scale; + int ofm_zero_point; + Array pool_shape; + IndexExpr ofm_channels; + Array strides; + Array padding; + String activation; + int clip_min; + int clip_max; + String upscale; + String ifm_layout; + String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuPoolingAttrs, "relay.attrs.EthosuPoolingAttrs") { + TVM_ATTR_FIELD(pooling_type) + .describe("The type of the pooling. 'AVG' - average pool, 'MAX' - max pool."); + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(pool_shape) + .describe("The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).") + .set_default(NullValue >()); + TVM_ATTR_FIELD(ofm_channels) + .describe(" The number of the Output Feature Map channels.") + .set_default(NullValue()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("The 2 dimensional strides as (stride_height, stride_width)."); + TVM_ATTR_FIELD(padding) + .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right).") + .set_default(Array({0, 0, 0, 0})); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(upscale) + .describe( + "The 2x2 upscaling mode to apply to the Input Feature Map tensor. " + "'NONE' - no upscaling. " + "'NEAREST' - upscale using nearest neighbour. " + "'ZEROS' - upscale using zeros.") + .set_default("NONE"); + TVM_ATTR_FIELD(ifm_layout) + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_layout) + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuPoolingAttrs); + +bool EthosuPoolingRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + int ifm_index = 0; + int result_index = 2; + ICHECK_EQ(types.size(), result_index + 1); + + const auto* ifm = types[ifm_index].as(); + if (ifm == nullptr) return false; + + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "EthosuPoolingAttrs cannot be nullptr."; + + if (param->pooling_type != "AVG" && param->pooling_type != "MAX") { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected pooling_type 'AVG' or 'MAX' but was " + << param->pooling_type); + return false; + } + + if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: Expected pool type(uint8) or type(int8) for ifm but was " + << ifm->dtype); + return false; + } + + // Assign ofm type + auto ofm_shape = EthosuInferKernelOutput( + ifm->shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels, + Array({1, 1}), param->strides, param->padding); + reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype)); + return true; +} + +Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale, + int ifm_zero_point, double ofm_scale, int ofm_zero_point, + Array pool_shape, IndexExpr ofm_channels, + Array strides, Array padding, String activation, + int clip_min, int clip_max, String upscale, String ifm_layout, + String ofm_layout) { + auto attrs = make_object(); + attrs->pooling_type = std::move(pooling_type); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->pool_shape = std::move(pool_shape); + attrs->ofm_channels = std::move(ofm_channels); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->upscale = std::move(upscale); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + static const Op& op = Op::Get("contrib.ethosu.pooling"); + return Call(op, {ifm, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_pooling").set_body_typed(MakeEthosuPooling); + +RELAY_REGISTER_OP("contrib.ethosu.pooling") + .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized pooling operator. + +This Relay operator corresponds to the hardware-implemented quantized +pooling operation found on Ethos(TM)-U NPU. It accepts either NHWC +or NHCWB16 format for the input data (input feature map, or IFM). + +Reference: https://developer.arm.com/documentation/102420/0200/ + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuPooling", EthosuPoolingRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index dce89aa91b65..9106b95c9217 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -24,6 +24,7 @@ #include "./device_copy.h" +#include #include #include #include @@ -31,6 +32,8 @@ #include #include "../../transforms/infer_layout_utils.h" +#include "../annotation/annotation.h" +#include "../call/call.h" #include "../type_relations.h" namespace tvm { @@ -86,6 +89,7 @@ on different devices. return {topi::identity(inputs[0])}; }); +// Get device copy props for original device copy op DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { if (call_node->op == DeviceCopyOp()) { ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument"; @@ -103,6 +107,19 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { } else { return {call_node->args[0], src_dev_type, dst_dev_type}; } + } else if (call_node->op == CallLoweredOp()) { + /* Get device props for a TIR function */ + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + + if (call_lowered_props.attrs.metadata.count("source_device") == 1 && + call_lowered_props.attrs.metadata.count("dst_device") == 1) { + ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; + return {call_lowered_props.lowered_func, + static_cast( + Downcast(call_lowered_props.attrs.metadata["source_device"])->value), + static_cast( + Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; + } } return {}; } diff --git a/src/relay/op/vm/vm.h b/src/relay/op/vm/vm.h index 802c8100125a..68d25b097bce 100644 --- a/src/relay/op/vm/vm.h +++ b/src/relay/op/vm/vm.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_OP_VM_VM_H_ #define TVM_RELAY_OP_VM_VM_H_ -#include "tvm/relay/expr.h" +#include namespace tvm { namespace relay { diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h index 23759a52ec41..79d5549d659a 100644 --- a/src/relay/qnn/utils.h +++ b/src/relay/qnn/utils.h @@ -49,7 +49,8 @@ static inline Array get_shape(const Type& type) { static inline int32_t GetQmin(const DataType& dtype) { ICHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { - auto* min_value = tir::as_const_int(tvm::min_value(dtype)); + auto min_value_expr = tvm::min_value(dtype); + auto* min_value = tir::as_const_int(min_value_expr); ICHECK(min_value != nullptr); return static_cast(min_value[0]); } else { @@ -61,7 +62,8 @@ static inline int32_t GetQmin(const DataType& dtype) { static inline int32_t GetQmax(const DataType& dtype) { ICHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { - auto* max_value = tir::as_const_int(tvm::max_value(dtype)); + auto max_value_expr = tvm::max_value(dtype); + auto* max_value = tir::as_const_int(max_value_expr); ICHECK(max_value != nullptr); return static_cast(max_value[0]); } else { diff --git a/src/relay/transforms/auto_scheduler_layout_rewrite.cc b/src/relay/transforms/auto_scheduler_layout_rewrite.cc index 7a86af8aeffa..c538dac048b3 100644 --- a/src/relay/transforms/auto_scheduler_layout_rewrite.cc +++ b/src/relay/transforms/auto_scheduler_layout_rewrite.cc @@ -34,7 +34,7 @@ #include #include -#include "../backend/compile_engine.h" +#include "../backend/te_compiler.h" #include "pattern_utils.h" namespace tvm { @@ -126,7 +126,8 @@ Expr AutoSchedulerLayoutRewriter::VisitExpr_(const CallNode* n) { CHECK(f) << "Could not find auto_scheduler.enter_layout_rewrite function."; (*f)(); - PrimFuncFor(GetRef(func), Target::Current(), [](std::string name) { return name; }); + tec::PrimFuncFor(GetRef(func), Target::Current(), + [](std::string name) { return name; }); f = runtime::Registry::Get("auto_scheduler.exit_layout_rewrite"); CHECK(f) << "Could not find ansor.exit_layout_rewrite function."; diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 28aeab60539c..38c3305d3194 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -262,7 +262,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. PopDeviceType(); - return OnDevice(expr, props.device_type, props.is_fixed); + return MaybeOnDevice(expr, props.device_type, props.is_fixed); } else { return DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 15784856edbf..b9fa0494d3b5 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -24,9 +24,11 @@ #include "./device_domains.h" +#include #include #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/device_copy.h" namespace tvm { @@ -47,20 +49,19 @@ constexpr size_t mix(size_t h1, size_t h2) { * See te_compiler.cc for where this rewriting occurs. */ DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { - auto tir_call_attrs = call_node->attrs.as(); - if (tir_call_attrs == nullptr) { - return {}; - } - if (tir_call_attrs->metadata.count("source_device") != 1 || - tir_call_attrs->metadata.count("dst_device") != 1) { - return {}; + if (call_node->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); + if (call_lowered_props.attrs.metadata.count("source_device") == 1 && + call_lowered_props.attrs.metadata.count("dst_device") == 1) { + ICHECK_EQ(call_lowered_props.arguments.size(), 1) << "device_copy is of arity 1"; + return {call_lowered_props.arguments[0], + static_cast( + Downcast(call_lowered_props.attrs.metadata["source_device"])->value), + static_cast( + Downcast(call_lowered_props.attrs.metadata["dst_device"])->value)}; + } } - ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; - return { - call_node->args[0], - static_cast( - Downcast(tir_call_attrs->metadata["source_device"])->value), - static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; + return {}; } } // namespace @@ -319,8 +320,12 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); + } else if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); + return DomainFor(call_lowered_props.lowered_func); } else { - // Defer to normal case where op can be an arbitrary expression. + // We still need to handle the case where the function / op is not lowered + // because the device planner runs before and after lowering. return DomainFor(call->op); } auto domain = MakeDomain(std::move(args_and_result)); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index dc61e79226b6..83429a9e616f 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/analysis/device_planner.cc + * \file src/relay/transforms/device_planner.cc * \brief Determines a unique device to hold the result of every Relay sub-expression. * * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index d545518c1c3c..c48a9b30967c 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -21,6 +21,7 @@ * \file constant_folding.cc */ #include +#include #include #include #include @@ -30,68 +31,80 @@ #include #include -#include "pattern_utils.h" +#include "../op/annotation/annotation.h" +#include "./device_aware_visitors.h" +#include "./pattern_utils.h" namespace tvm { namespace relay { +namespace transform { -using FInterpreter = runtime::TypedPackedFunc; - -class ConstantChecker : private ExprVisitor { - public: - // Check whether an expression is constant. The results are memoized. - bool Check(const Expr& expr) { - // The `ConstantNode` case is common enough that we check directly for the - // case here, to avoid the time overhead of dispatching through the vtable - // and the space overhead of memoizing always-true results. - if (expr.as()) { - return true; - } - const auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - VisitExpr(expr); - return memo_[expr]; // return memoized result or the default value false - } +namespace { +/*! + * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device" + * annotation CallNode (which serves only to associate a device to the constant and has no + * operational effect). + */ +bool IsSimpleConstant(const Expr& expr) { + return AsIgnoringOnDevice(expr) != nullptr; +} - private: - std::unordered_map memo_; - - void VisitExpr_(const TupleNode* n) final { - bool result = true; - for (const auto& field : n->fields) { - if (!Check(field)) { - result = false; - break; - } - } - memo_[GetRef(n)] = result; +/*! + * \brief Returns whether \p expr \p IsSimpleConstant directly or is a tuple of + * \p IsComplexConstant expressions. + */ +bool IsComplexConstant(const Expr& expr) { + if (IsSimpleConstant(expr)) { + return true; + } else if (const auto* tuple_node = AsIgnoringOnDevice(expr)) { + return std::all_of(tuple_node->fields.begin(), tuple_node->fields.end(), IsComplexConstant); + } else { + return false; } -}; - -bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } - -TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck); +} // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. class ConstantFolder : public MixedModeMutator { public: explicit ConstantFolder(IRModule module) - : module_(module), + : module_(std::move(module)), device_copy_op_(Op::Get("device_copy")), shape_of_op_(Op::Get("shape_of")), vm_shape_of_op_(Op::Get("vm.shape_of")), cast_op_(Op::Get("cast")), ndarray_size_op_(Op::Get("ndarray_size")) {} - using MixedModeMutator::VisitExpr_; + private: + using ExprMutator::VisitExpr_; - Expr VisitExpr_(const LetNode* op) final { + Expr VisitExpr_(const LetNode* let_node) final { auto pre_visit = [this](const LetNode* op) { // Rely on the Memoizer to cache pre-visit values - Expr value = this->Mutate(op->value); - if (value.as()) { - this->memo_[op->var] = value; + Expr new_value = Mutate(op->value); + if (IsSimpleConstant(new_value)) { + // Inline new value (along with any on_device annotation wrapping it) at all occurrences of + // the variable. + // + // We need to retain any "on_device" annotation so that downstream 'device aware' + // passes can still retrieve the device for the constant in its new position(s). Eg: + // def @f(..., result_device_type=D) { + // let %x = on_device(... something we eval to a constant..., device_type=E) + // @f(..., %x, ...) + // } + // Here the default device is D, whereas the argument %x to @f is on E (and @f expects + // that). No on_device annotation is required in the call according to the convention used + // by the device-aware visitors. + // + // However once we've inlined the constant we need to insert an on_device, again to + // respect the convention used by the device-aware visitors. + // def @f(..., result_device_type=D) { + // @f(..., on_device(...the constant..., device_type=E), ...) + // } + VLOG(1) << "Replacing let-binding for " << op->var->name_hint() + << " with constant:" << std::endl + << PrettyPrint(new_value); + memo_[op->var] = new_value; } else { this->Mutate(op->var); } @@ -99,116 +112,117 @@ class ConstantFolder : public MixedModeMutator { auto post_visit = [this](const LetNode* op) { Expr expr = GetRef(op); // Rely on the Memoizer to cache pre-visit values - Expr value = this->Mutate(op->value); - if (value.as()) { - this->memo_[expr] = this->Mutate(op->body); + Expr new_value = this->Mutate(op->value); + if (IsSimpleConstant(new_value)) { + // The let-bound value has been inlined, drop the let-binding itself. + this->memo_[expr] = Mutate(op->body); } else { - Var var = Downcast(this->Mutate(op->var)); - Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + Var new_var = Downcast(this->Mutate(op->var)); + Expr new_body = this->Mutate(op->body); + if (new_var.same_as(op->var) && new_value.same_as(op->value) && + new_body.same_as(op->body)) { this->memo_[expr] = expr; } else { - this->memo_[expr] = Let(var, value, body); + this->memo_[expr] = Let(new_var, new_value, new_body, op->span); } } }; - ExpandANormalForm(op, pre_visit, post_visit); - return memo_[GetRef(op)]; + ExpandANormalForm(let_node, pre_visit, post_visit); + return memo_[GetRef(let_node)]; } - bool inside_primitive = false; - Expr VisitExpr_(const FunctionNode* op) final { - if (op->HasNonzeroAttr(attr::kPrimitive)) { - ICHECK_EQ(inside_primitive, false); - inside_primitive = true; - auto ret = ExprMutator::VisitExpr_(op); - inside_primitive = false; + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + ICHECK_EQ(inside_primitive_, false); + inside_primitive_ = true; + auto ret = ExprMutator::VisitExpr_(function_node); + inside_primitive_ = false; return ret; } else { - return ExprMutator::VisitExpr_(op); + return ExprMutator::VisitExpr_(function_node); } } - Expr VisitExpr_(const IfNode* op) final { - auto new_cond = ExprMutator::VisitExpr(op->cond); - if (auto const_cond = new_cond.as()) { - if (reinterpret_cast(const_cond->data->data)[0]) { - return ExprMutator::VisitExpr(op->true_branch); - } else { - return ExprMutator::VisitExpr(op->false_branch); - } + Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { + Call pre_call = GetRef(pre_call_node); + if (inside_primitive_) { + return pre_call; } - return ExprMutator::VisitExpr_(op); - } - Expr Rewrite_(const CallNode* call, const Expr& post) final { - if (inside_primitive) { - return GetRef(call); - } - static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); - - auto origin_args = call->args; - call = post.as(); - // We don't constant fold function with zero arguments. - // This is a heuristic that is useful. - // For example it is harmful to fold ones(shape=(4, 5)). - if (call->args.size() == 0) return post; - const OpNode* op = call->op.as(); - if (op == nullptr) return post; - // skip stateful ops. - if (op_stateful.get(GetRef(op), false)) return post; - // Try to evaluate shape_of op - if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) { - return EvaluateShapeOf(post, origin_args, call->attrs); - } + Call post_call = Downcast(post); - if (call->op == ndarray_size_op_) { - return EvaluateNdarraySize(post, origin_args, call->attrs); + if (post_call->args.empty()) { + // We don't constant fold function with zero arguments. + // This is a heuristic that is useful. + // For example it is harmful to fold ones(shape=(4, 5)). + return std::move(pre_call); } - // We should think about potentially constant evaluation over these ops too. static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); - if (const auto* call_node = call->op.as()) { - Op op = GetRef(call_node); - if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) { - return GetRef(call); - } - } - bool all_const_args = true; - for (Expr arg : call->args) { - if (!checker_.Check(arg)) { - all_const_args = false; - } + const auto* op_node = post_call->op.as(); + if (op_node == nullptr) { + // Only evaluate primitives. + return std::move(post_call); } - if (all_const_args) { - return ConstEvaluate(post); - } else { - return post; + Op op = GetRef(op_node); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); + if (op_stateful.get(op, false)) { + // skip stateful ops. + return std::move(post_call); } + // Try to evaluate shape_of and ndarray_size ops + // Use the original call rather than new_call here since it still has valid checked_type + // fields. These operators don't care about the value of their argument anyway. + if (Optional opt_result = EvaluateShapeOf(pre_call)) { + return opt_result.value(); + } + // Use the original call rather than new_call here since it still has valid checked_type + // fields. This operator doesn't care about the value of its argument anyway. + if (Optional opt_result = EvaluateNdarraySize(pre_call)) { + return opt_result.value(); + } + if ((fnoncomputational.count(op) && fnoncomputational[op]) || op == device_copy_op_ || + op == shape_of_op_ || op == vm_shape_of_op_ || op == ndarray_size_op_) { + // We should think about potentially constant evaluation over these ops too. + return std::move(post_call); + } + if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) { + // At least one non-constant argument. + return std::move(post_call); + } + // During evaluation we have obviously lost all on_device annotations. However any + // on_device wrapping this call will be left in place. + return ConstEvaluate(post_call); } - Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { - op = post.as(); - if (const auto* tuple = op->tuple.as()) { - return tuple->fields[op->index]; - } else { - return post; + Expr VisitExpr_(const IfNode* if_node) final { + If new_if = Downcast(ExprMutator::VisitExpr_(if_node)); + if (const auto* const_node = AsIgnoringOnDevice(new_if->cond)) { + if (reinterpret_cast(const_node->data->data)[0]) { + return new_if->true_branch; + } else { + return new_if->false_branch; + } } + return std::move(new_if); } - private: - // Internal constant checker - ConstantChecker checker_; - // Module - IRModule module_; - - // Cache the following ops for equivalence checking in this pass. - const Op& device_copy_op_; - const Op& shape_of_op_; - const Op& vm_shape_of_op_; - const Op& cast_op_; - const Op& ndarray_size_op_; + Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, + const Expr& post_tuple_get_item) final { + const auto* post_tuple_get_item_node = post_tuple_get_item.as(); + if (const auto* tuple_node = AsIgnoringOnDevice(post_tuple_get_item_node->tuple)) { + Expr result = tuple_node->fields[tuple_get_item_node->index]; + OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); + if (props.body.defined()) { + // (on_device((x, y, z), device_type=D).1 ==> on_device(y, device_type=D) + return MaybeOnDevice(result, props.device_type, props.is_fixed); + } else { + return result; + } + } + return std::move(post_tuple_get_item); + } // Convert value to expression. Expr ObjectToExpr(const ObjectRef& value) { @@ -224,35 +238,53 @@ class ConstantFolder : public MixedModeMutator { return Tuple(fields); } else { LOG(FATAL) << "Cannot handle " << value->GetTypeKey(); - return Expr(); + return {}; } } + // Constant evaluate an expression. - Expr ConstEvaluate(Expr expr) { + Expr ConstEvaluate(const Expr& expr) { + VLOG_CONTEXT << "ConstEvaluate"; + VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr); + + // We'll invoke the interpreter using the generic CPU device and target. Technically there's + // no guarantee the results we bitwise equal what we'd get on the true device, however to + // support cross-compilation we don't want to assume the true device is available. Device dev; dev.device_type = kDLCPU; dev.device_id = 0; Target target = Target("llvm"); - // use a fresh build context in case we are already in a build context. + // Use a fresh build context in case we are already in a build context. // needed for both execution and creation(due to JIT) With fresh_build_ctx(transform::PassContext::Create()); - return ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); + Expr result = + ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target)); + VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); + return result; } - // Evaluate a call to the shape_of operator for tensors with constant - // shapes. - Expr EvaluateShapeOf(Expr expr, Array args, Attrs attrs) { - Expr input = args[0]; - const auto* param = attrs.as(); + /*! + * \brief Returns constant shape result of \p call if it of form \p shape_of(e) and \p e has + * a non-dynamic tensor shape. Returns null otherwise. + */ + Optional EvaluateShapeOf(const Call& call) { + if (call->op != shape_of_op_ && call->op != vm_shape_of_op_) { + return {}; + } + + VLOG(1) << "Evaluating for shape_of:" << std::endl << PrettyPrint(call); + ICHECK_EQ(call->args.size(), 1); + const auto* param = call->attrs.as(); ICHECK(param != nullptr); + Expr input = call->args[0]; tvm::Array ishape; - if (auto opt = GetConstantShape(input)) { - ishape = opt.value(); + if (Optional> opt_shape = GetConstantShape(input)) { + ishape = opt_shape.value(); } else { - return expr; + return {}; } // Get the constant shape @@ -261,26 +293,26 @@ class ConstantFolder : public MixedModeMutator { dev.device_id = 0; runtime::NDArray value; DLDataType cdtype = DataType::Int(32); - if (ishape.size() == 0) { + if (ishape.empty()) { value = runtime::NDArray::Empty({}, cdtype, dev); } else { ICHECK_NE(ishape.size(), 0); std::vector cshape = {static_cast(ishape.size())}; value = runtime::NDArray::Empty(cshape, cdtype, dev); - int32_t* dims = static_cast(value->data); + auto* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; for (size_t i = 0; i < ishape.size(); ++i) { - if (const IntImmNode* dim = ishape[i].as()) { + if (const auto* dim = ishape[i].as()) { dims[i] = dim->value; } else { - return expr; + return {}; } } } Constant shape = Downcast(ObjectToExpr(value)); - if (shape->data.Shape().size() == 0 && GetScalarFromConstant(shape) == 0) { + if (shape->data.Shape().empty() && GetScalarFromConstant(shape) == 0) { auto ndarray = runtime::NDArray::Empty({}, cdtype, dev); shape = Constant(ndarray); } @@ -288,18 +320,25 @@ class ConstantFolder : public MixedModeMutator { return CastValue(shape, param->dtype); } - // Evaluate a call to the ndarray_size operator for tensors with constant - // shapes. - Expr EvaluateNdarraySize(Expr expr, Array args, Attrs attrs) { - Expr input = args[0]; - const auto* param = attrs.as(); + /*! + * \brief Returns the constant NDArray size of result of \p call if it is of the form + * \p ndarray_size(e) and \p e has non-dynamic tensor type. Returns null otherwise. + */ + Optional EvaluateNdarraySize(const Call& call) { + if (call->op != ndarray_size_op_) { + return {}; + } + VLOG(1) << "Evaluating for ndarray_size:" << std::endl << PrettyPrint(call); + ICHECK_EQ(call->args.size(), 1); + Expr input = call->args[0]; + const auto* param = call->attrs.as(); ICHECK(param != nullptr); tvm::Array ishape; - if (auto opt = GetConstantShape(input)) { - ishape = opt.value(); + if (Optional> opt_shape = GetConstantShape(input)) { + ishape = opt_shape.value(); } else { - return expr; + return {}; } // Get the constant size @@ -309,17 +348,17 @@ class ConstantFolder : public MixedModeMutator { runtime::NDArray value; DLDataType cdtype = DataType::Int(32); value = runtime::NDArray::Empty({}, cdtype, dev); - int32_t* data = static_cast(value->data); - if (ishape.size() == 0) { + auto* data = static_cast(value->data); + if (ishape.empty()) { *data = 0; } else { *data = 1; using ::tvm::tir::IntImmNode; for (size_t i = 0; i < ishape.size(); ++i) { - if (const IntImmNode* dim = ishape[i].as()) { + if (const auto* dim = ishape[i].as()) { *data *= dim->value; } else { - return expr; + return {}; } } } @@ -337,31 +376,57 @@ class ConstantFolder : public MixedModeMutator { } Optional> GetConstantShape(const Expr& input) { - tvm::Array ishape; - if (const ConstantNode* op = input.as()) { - ishape = op->tensor_type()->shape; + if (const auto* const_node = AsIgnoringOnDevice(input)) { + // TODO(mbs): This is not necessary since we only ever ask for the shapes for + // pre-rewritten expressions which will always have a checked_type. + return const_node->tensor_type()->shape; } else if (input->checked_type_.defined()) { - ishape = input->checked_type().as()->shape; + return input->checked_type().as()->shape; } else { - return Optional>(nullptr); + return {}; } - - return Optional>(ishape); } + + // Module + IRModule module_; + + // Cache the following ops for equivalence checking in this pass. + const Op& device_copy_op_; + const Op& shape_of_op_; + const Op& vm_shape_of_op_; + const Op& cast_op_; + const Op& ndarray_size_op_; + + // True if currently within a "primitive" Relay Function. + bool inside_primitive_ = false; }; -Expr FoldConstant(const Expr& expr, const IRModule& mod) { - return ConstantFolder(mod).Mutate(expr); -} +} // namespace -TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant); +TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexConstant); -namespace transform { +/*! + * \brief Returns \p expr with any constants expressions evaluated and let-bound constants + * inlined. Returns \p expr unchanged if no change. + * + * CAUTION: The importers rely on this function returning \p expr unchanged to preserve sharing + * from their p.o.v. Furthermore, this function can be called before conversion to ANF so + * we must avoid all recursion. + */ +Expr FoldConstantExpr(const Expr& expr, const IRModule& mod) { + VLOG_CONTEXT << "FoldConstantExpr"; + VLOG(1) << "folding:" << std::endl << PrettyPrint(expr); + Expr result = ConstantFolder(mod).VisitExpr(expr); + VLOG(1) << "folded to:" << std::endl << PrettyPrint(result); + return result; +} + +TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstantExpr); Pass FoldConstant() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(FoldConstant(f, m)); + return Downcast(FoldConstantExpr(f, m)); }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } @@ -369,6 +434,5 @@ Pass FoldConstant() { TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform - } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 71917c31ec00..a328eaa82aa2 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ #include "../backend/te_compiler.h" #include "../backend/te_compiler_cache.h" #include "../op/annotation/annotation.h" +#include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" @@ -62,8 +64,9 @@ inline Constant MakeConstant(const std::vector& value) { } inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, - Array assert_shape) { - auto offset = MakeConstantScalar(DataType::Int(64), 0); + Array assert_shape, DLDeviceType offset_device_type) { + auto offset = + OnDevice(MakeConstantScalar(DataType::Int(64), 0), offset_device_type, /*is_fixed=*/true); return AllocTensor(storage, offset, shape, dtype, assert_shape); } @@ -73,12 +76,11 @@ bool IsReshapeOnly(const Expr& expr) { return func->HasNonzeroAttr(attr::kReshapeOnly); } if (const CallNode* call = expr.as()) { - if (call->attrs.defined()) { - if (auto tir_call_attrs = call->attrs.as()) { - Map metadata = tir_call_attrs->metadata; - return metadata.count(attr::kReshapeOnly) && - (Downcast(metadata[attr::kReshapeOnly])->value == 1); - } + if (call->op == CallLoweredOp()) { + CallLoweredProps call_lowered_props = GetCallLoweredProps(call); + Map metadata = call_lowered_props.attrs.metadata; + return metadata.count(attr::kReshapeOnly) && + (Downcast(metadata[attr::kReshapeOnly])->value == 1); } } return false; @@ -267,8 +269,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto sto = scope->Push(var, value); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), - dev.device_type, /*is_fixed=*/true); + auto tensor = OnDevice( + AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape, cpu_device_.device_type), + dev.device_type, /*is_fixed=*/true); Var tensor_var("tensor_" + name_hint, Type(nullptr)); return scope->Push(tensor_var, tensor); } @@ -367,14 +370,16 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), + auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape, + cpu_device_.device_type), dev.device_type, /*is_fixed=*/true); Var out_var("out_" + std::to_string(i), Type(nullptr)); outs.push_back(scope->Push(out_var, alloc)); } Tuple tuple_outs(outs); - auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), dev.device_type, /*is_fixed=*/true); + auto call = InvokeTVMOp(func, ins, tuple_outs); + auto invoke = OnDevice(call, dev.device_type, /*is_fixed=*/true); scope->Push(invoke); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); @@ -394,7 +399,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = MakeConstant(shape); + shape_expr = OnDevice(MakeConstant(shape), cpu_device_.device_type, /*is_fixed=*/true); } return ReshapeTensor(new_args[0], shape_expr, ret_ty->shape); } @@ -415,7 +420,6 @@ Pass ManifestAlloc(Target target_host, Map targets) { CheckAndUpdateHostConsistency(&targets, &target_host); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { - DLOG(INFO) << "tvm::relay::transform::ManifestAlloc"; // We need to mutate module, therefore making a copy of it. mod.CopyOnWrite(); mod->ImportFromStd("core.rly"); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index f74cf983ccae..6e52cbfbe55a 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -43,6 +43,7 @@ #include #include "../analysis/annotated_region_set.h" +#include "../backend/name_transforms.h" #include "../backend/utils.h" #include "pass_utils.h" @@ -501,7 +502,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); if (func->GetAttr(attr::kCompiler).defined()) { - auto fn_name_mangled = mangle_fn_(pair.first->name_hint); + auto fn_name_mangled = relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)); GlobalVar gvar = GlobalVar(fn_name_mangled); mangled_gvars_[pair.first->name_hint] = gvar; } @@ -519,7 +520,8 @@ class NameMangleExtFuncs : public MixedModeMutator { if (func->GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; - new_dict.Set(tvm::attr::kGlobalSymbol, String(mangle_fn_(pair.first->name_hint))); + new_dict.Set(tvm::attr::kGlobalSymbol, + String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, DictAttrs(new_dict)); new_module->Add(mangled_gvars_[pair.first->name_hint], func); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 5ca6d86b1d52..6d74e48e871e 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -666,7 +666,7 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Type checked_type = solver_->Resolve(it->second.checked_type); if (checked_type.as() != nullptr) { - this->solver_->diag_ctx_.Emit( + this->solver_->Emit( Diagnostic::Error(op->span) << "The type inference pass was unable to infer a type for this expression.\n" << "This usually occurs when an operator call is under constrained in some way," diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index aa718a303744..90897a9542b6 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -49,7 +49,10 @@ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 + +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! \brief Maximum length of a PackedFunc function name. */ #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 diff --git a/src/runtime/crt/graph_executor/graph_executor.c b/src/runtime/crt/graph_executor/graph_executor.c index 34e81c7d33b1..3fea408d9760 100644 --- a/src/runtime/crt/graph_executor/graph_executor.c +++ b/src/runtime/crt/graph_executor/graph_executor.c @@ -77,7 +77,7 @@ int NodeEntry_Load(TVMGraphExecutorNodeEntry* entry, JSONReader* reader) { void TVMGraphExecutorNode_LoadAttrs(TVMGraphExecutorNode* node, JSONReader* reader, TVMOpParam* param) { int bitmask = 0; - char key[20], value[120]; + char key[20], value[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; memset(param, 0, sizeof(TVMOpParam)); memset(key, 0, sizeof(key)); memset(value, 0, sizeof(value)); @@ -796,13 +796,13 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl char* names = NULL; DLDevice dev = {kDLCPU, 0}; tvm_crt_error_t err = TVMPlatformMemoryAllocate( - TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count, dev, (void**)&names); + TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count, dev, (void**)&names); if (err != kTvmErrorNoError) { fprintf(stderr, "memory allocate error: %08x", err); status = -1; return status; } - memset(names, 0, TVM_CRT_MAX_STRLEN_FUNCTION_NAME * executor->nodes_count); + memset(names, 0, TVM_CRT_MAX_STRLEN_PARAM_NAME * executor->nodes_count); uint64_t names_count; int idx; memcpy(&names_count, bptr, sizeof(names_count)); @@ -811,11 +811,11 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl uint64_t name_length; memcpy(&name_length, bptr, sizeof(name_length)); bptr += sizeof(name_length); - if (name_length >= TVM_CRT_MAX_STRLEN_FUNCTION_NAME) { + if (name_length >= TVM_CRT_MAX_STRLEN_PARAM_NAME) { fprintf(stderr, "Error: function name longer than expected.\n"); status = -1; } - memcpy(names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, bptr, name_length); + memcpy(names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, bptr, name_length); bptr += name_length; } @@ -831,9 +831,9 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl for (idx = 0; idx < size; idx++) { int32_t in_idx = - TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + TVMGraphExecutor_GetInputIndex(executor, names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx); + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx); uint32_t eid = TVMGraphExecutor_GetEntryId(executor, executor->input_nodes[in_idx], 0); if (!(eid < executor->data_entry_count)) { fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, @@ -859,7 +859,7 @@ int TVMGraphExecutor_LoadParams(TVMGraphExecutor* executor, const char* param_bl #if TVM_CRT_DEBUG TVMNDArray* entry = &(executor->data_entry[eid]); printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", - names + TVM_CRT_MAX_STRLEN_FUNCTION_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, + names + TVM_CRT_MAX_STRLEN_PARAM_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) #endif // TVM_CRT_DEBUG } @@ -1181,13 +1181,6 @@ int TVMGraphExecutor_Init(TVMGraphExecutor* executor, const char* graph_json, return status; } status = TVMGraphExecutor_SetupOpExecs(executor); - if (status != 0) { - if (status != 0) { - return status; - } - - return status; - } return status; } diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h index c67c43357363..d4429308b650 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/graph_executor/graph_executor.h @@ -60,7 +60,7 @@ typedef struct TVMGraphExecutorNode { // operator type in string char op_type[16]; // name of the op - char name[120]; + char name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; // parameters TVMOpParam param; // inputs diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 33a87c9a2be2..b4d7b41b7f4a 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -112,14 +112,14 @@ class CUDADeviceAPI final : public DeviceAPI { ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; void* ret; if (dev.device_type == kDLCUDAHost) { - DLOG(INFO) << "allocating " << nbytes << "bytes on host"; + VLOG(1) << "allocating " << nbytes << "bytes on host"; CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); size_t free_mem, total_mem; CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem)); - DLOG(INFO) << "allocating " << nbytes << " bytes on device, with " << free_mem - << " bytes currently free out of " << total_mem << " bytes available"; + VLOG(1) << "allocating " << nbytes << " bytes on device, with " << free_mem + << " bytes currently free out of " << total_mem << " bytes available"; CUDA_CALL(cudaMalloc(&ret, nbytes)); } return ret; @@ -127,11 +127,11 @@ class CUDADeviceAPI final : public DeviceAPI { void FreeDataSpace(Device dev, void* ptr) final { if (dev.device_type == kDLCUDAHost) { - DLOG(INFO) << "freeing host memory"; + VLOG(1) << "freeing host memory"; CUDA_CALL(cudaFreeHost(ptr)); } else { CUDA_CALL(cudaSetDevice(dev.device_id)); - DLOG(INFO) << "freeing device memory"; + VLOG(1) << "freeing device memory"; CUDA_CALL(cudaFree(ptr)); } } diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index c439bde82497..81eb30ee12d2 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -37,62 +37,102 @@ namespace tvm { namespace runtime { -// Dynamic shared libary. -// This is the default module TVM used for host-side AOT +/*! + * \brief Dynamic shared library object used to load + * and retrieve symbols by name. This is the default + * module TVM uses for host-side AOT compilation. + */ class DSOLibrary final : public Library { public: - ~DSOLibrary() { - if (lib_handle_) Unload(); - } - void Init(const std::string& name) { Load(name); } - - void* GetSymbol(const char* name) final { return GetSymbol_(name); } + ~DSOLibrary(); + /*! + * \brief Initialize by loading and storing + * a handle to the underlying shared library. + * \param name The string name/path to the + * shared library over which to initialize. + */ + void Init(const std::string& name); + /*! + * \brief Returns the symbol address within + * the shared library for a given symbol name. + * \param name The name of the symbol. + * \return The symbol. + */ + void* GetSymbol(const char* name) final; private: - // Platform dependent handling. + /*! \brief Private implementation of symbol lookup. + * Implementation is operating system dependent. + * \param The name of the symbol. + * \return The symbol. + */ + void* GetSymbol_(const char* name); + /*! \brief Implementation of shared library load. + * Implementation is operating system dependent. + * \param The name/path of the shared library. + */ + void Load(const std::string& name); + /*! \brief Implementation of shared library unload. + * Implementation is operating system dependent. + */ + void Unload(); + #if defined(_WIN32) - // library handle + //! \brief Windows library handle HMODULE lib_handle_{nullptr}; - - void* GetSymbol_(const char* name) { - return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) - } - - // Load the library - void Load(const std::string& name) { - // use wstring version that is needed by LLVM. - std::wstring wname(name.begin(), name.end()); - lib_handle_ = LoadLibraryW(wname.c_str()); - ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; - } - - void Unload() { - FreeLibrary(lib_handle_); - lib_handle_ = nullptr; - } #else - // Library handle + // \brief Linux library handle void* lib_handle_{nullptr}; - // load the library - void Load(const std::string& name) { - lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - ICHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name << " " << dlerror(); - } - - void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } - - void Unload() { - dlclose(lib_handle_); - lib_handle_ = nullptr; - } #endif }; -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { +DSOLibrary::~DSOLibrary() { + if (lib_handle_) Unload(); +} + +void DSOLibrary::Init(const std::string& name) { Load(name); } + +void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); } + +#if defined(_WIN32) + +void* DSOLibrary::GetSymbol_(const char* name) { + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) +} + +void DSOLibrary::Load(const std::string& name) { + // use wstring version that is needed by LLVM. + std::wstring wname(name.begin(), name.end()); + lib_handle_ = LoadLibraryW(wname.c_str()); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; +} + +void DSOLibrary::Unload() { + FreeLibrary(lib_handle_); + lib_handle_ = nullptr; +} + +#else + +void DSOLibrary::Load(const std::string& name) { + lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); +} + +void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } + +void DSOLibrary::Unload() { + dlclose(lib_handle_); + lib_handle_ = nullptr; +} + +#endif + +ObjectPtr CreateDSOLibraryObject(std::string library_path) { auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); -}); + n->Init(library_path); + return n; +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/android/hexagon_device.h b/src/runtime/hexagon/android/hexagon_device.h new file mode 100644 index 000000000000..552b8f971369 --- /dev/null +++ b/src/runtime/hexagon/android/hexagon_device.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ + +#include +#include + +#include +#include + +#include "../../meta_data.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +/*! + * \brief Low-level interface for communicating with Hexagon devices. + */ +class Device { + public: + /*! + * \brief Allocate memory on device. + * \param size Requested size. + * \param align Requested alignment. + * \return Pointer (local to the device) of the allocated memory, + * or nullptr if allocation failed. + */ + virtual void* Alloc(unsigned size, unsigned align) = 0; + /*! + * \brief Release allocated memory on device. + * \param ptr Pointer to memory previously allocated by \ref Alloc. + */ + virtual void Free(void* ptr) = 0; + /*! + * \brief Allocate VTCM memory on device. + * \param size Requested size. + * \param align Requested alignment. + * \return Pointer (local to the device) of the allocated memory, + * or nullptr if allocation failed. + */ + virtual void* AllocVtcm(unsigned size, unsigned align) = 0; + /*! + * \brief Release allocated VTCM memory on device. + * \param ptr Pointer to memory previously allocated by \ref AllocVtcm. + */ + virtual void FreeVtcm(void* ptr) = 0; + /*! + * \brief Copy a block of data on device to another location on the device. + * \param dst Pointer (local to device) to the destination buffer. + * \param src Pointer (local to device) of the source buffer. + * \param len Number of bytes to copy. + */ + virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; + /*! + * \brief Copy a block of data from device to host. + * \param host_dst Pointer (local to host) to the destination buffer. + * \param src Pointer (local to device) to the source buffer. + * \param len Number of bytes to copy. + */ + virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; + /*! + * \brief Copy a block of data from host to device. + * \param dst Pointer (local to device) to the destination buffer. + * \param host_src Pointer (local to host) to the source buffer. + * \param len Number of bytes to copy. + */ + virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; + /*! + * \brief Load a module (typically a shared library) into device. + * \param data Name of the shared library. + * \param fmt Format of the library (currently ignored). + * \return Pointer to the loaded module. + * \note Currently only one module can be loaded at any given time. + */ + virtual void* Load(const std::string& data, const std::string& fmt) = 0; + /*! + * \brief Unload a module from device. + * \param mod Pointer to a loaded module returned by \ref Load. + */ + virtual void Unload(void* mod) = 0; + /*! + * \brief Find the address of an object in the currently loaded module. + * \param sym Name of the object. + * \return Address of the located object, or nullptr if object was + * not found. + */ + virtual void* Resolve(const std::string& sym) = 0; + /*! + * \brief Invoke a function on device with given arguments. + * \param func Address (local to device) of the function to call. + * \param scalar Pointer to an array of 32-bit values that will be + * passed via consecutive registers: r0..r5. This array + * includes dummy values for skipped registers. + * \param sc_num Number of values in the "scalar" array. + * \param stack Pointer to an array of 32-bit values that will be + * passed on the stack. This array includes dummy values + * for padding. + * \param st_num Number of values in the "stack" array. + */ + virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) = 0; + + virtual ~Device() = 0; + + static std::shared_ptr Global(); + static bool ValidateDeviceId(decltype(DLDevice::device_id) device_id) { + // Only supporting a single device for now. + return device_id == 0; + } +}; + +} // namespace hexagon + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_HEXAGON_ANDROID_HEXAGON_DEVICE_H_ diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/android/hexagon_device_api.cc similarity index 99% rename from src/runtime/hexagon/hexagon_device_api.cc rename to src/runtime/hexagon/android/hexagon_device_api.cc index a07a7c683026..ec50b4bf93a5 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/android/hexagon_device_api.cc @@ -24,7 +24,7 @@ #include #include -#include "hexagon_module.h" +#include "hexagon_device.h" namespace tvm { namespace runtime { diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/android/hexagon_module.cc similarity index 99% rename from src/runtime/hexagon/hexagon_module.cc rename to src/runtime/hexagon/android/hexagon_module.cc index 41aa5855ceeb..e386daf7dc7c 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/android/hexagon_module.cc @@ -17,7 +17,7 @@ * under the License. */ -#include "hexagon_module.h" +#include "../hexagon_module.h" #ifdef __ANDROID__ #include @@ -31,8 +31,8 @@ #include #include -#include "../file_utils.h" -#include "../meta_data.h" +#include "../../file_utils.h" +#include "hexagon_device.h" namespace tvm { namespace runtime { diff --git a/src/runtime/hexagon/hexagon_posix.cc b/src/runtime/hexagon/android/hexagon_posix.cc similarity index 100% rename from src/runtime/hexagon/hexagon_posix.cc rename to src/runtime/hexagon/android/hexagon_posix.cc diff --git a/src/runtime/hexagon/sim/driver/CMakeLists.txt b/src/runtime/hexagon/android/sim/driver/CMakeLists.txt similarity index 95% rename from src/runtime/hexagon/sim/driver/CMakeLists.txt rename to src/runtime/hexagon/android/sim/driver/CMakeLists.txt index dbac99534383..ddcec9169211 100644 --- a/src/runtime/hexagon/sim/driver/CMakeLists.txt +++ b/src/runtime/hexagon/android/sim/driver/CMakeLists.txt @@ -61,10 +61,10 @@ add_executable(sim_dev ${SOURCE_FILES}) target_include_directories(sim_dev PUBLIC "." PUBLIC ".." - PUBLIC "../../../../../include" + PUBLIC "../../../../../../include" ) target_include_directories(sim_dev SYSTEM - PUBLIC "../../../../../3rdparty/dlpack/include" + PUBLIC "../../../../../../3rdparty/dlpack/include" ) target_link_libraries(sim_dev "-ldl") diff --git a/src/runtime/hexagon/sim/driver/README.md b/src/runtime/hexagon/android/sim/driver/README.md similarity index 100% rename from src/runtime/hexagon/sim/driver/README.md rename to src/runtime/hexagon/android/sim/driver/README.md diff --git a/src/runtime/hexagon/sim/driver/fake_pthread.cc b/src/runtime/hexagon/android/sim/driver/fake_pthread.cc similarity index 100% rename from src/runtime/hexagon/sim/driver/fake_pthread.cc rename to src/runtime/hexagon/android/sim/driver/fake_pthread.cc diff --git a/src/runtime/hexagon/sim/driver/pthread.h b/src/runtime/hexagon/android/sim/driver/pthread.h similarity index 94% rename from src/runtime/hexagon/sim/driver/pthread.h rename to src/runtime/hexagon/android/sim/driver/pthread.h index 7ec74b4f99f5..b4d559c44f8e 100644 --- a/src/runtime/hexagon/sim/driver/pthread.h +++ b/src/runtime/hexagon/android/sim/driver/pthread.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ -#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ #define _PROVIDE_POSIX_TIME_DECLS 1 #include @@ -89,4 +89,4 @@ pthread_t pthread_self(void); } #endif -#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_PTHREAD_H_ diff --git a/src/runtime/hexagon/sim/driver/sched.h b/src/runtime/hexagon/android/sim/driver/sched.h similarity index 84% rename from src/runtime/hexagon/sim/driver/sched.h rename to src/runtime/hexagon/android/sim/driver/sched.h index cc63630f2072..621ef218b795 100644 --- a/src/runtime/hexagon/sim/driver/sched.h +++ b/src/runtime/hexagon/android/sim/driver/sched.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ -#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ #ifdef __cplusplus extern "C" { @@ -28,4 +28,4 @@ int sched_yield(void); } #endif -#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_DRIVER_SCHED_H_ diff --git a/src/runtime/hexagon/sim/driver/sim_device.cc b/src/runtime/hexagon/android/sim/driver/sim_device.cc similarity index 100% rename from src/runtime/hexagon/sim/driver/sim_device.cc rename to src/runtime/hexagon/android/sim/driver/sim_device.cc diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc similarity index 99% rename from src/runtime/hexagon/sim/hexagon_device_sim.cc rename to src/runtime/hexagon/android/sim/hexagon_device_sim.cc index 14ab4c30e2f2..250259832597 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc @@ -32,7 +32,7 @@ #include #include -#include "../hexagon_module.h" +#include "../hexagon_device.h" #include "HexagonWrapper.h" #include "hexagon_sim_proto.h" @@ -121,7 +121,7 @@ struct non_const_str { ICHECK_EQ(pointers_.size(), 1); return pointers_[0]; } - operator char* *() { return pointers_.data(); } + operator char**() { return pointers_.data(); } private: std::vector pointers_; diff --git a/src/runtime/hexagon/sim/hexagon_sim_proto.h b/src/runtime/hexagon/android/sim/hexagon_sim_proto.h similarity index 90% rename from src/runtime/hexagon/sim/hexagon_sim_proto.h rename to src/runtime/hexagon/android/sim/hexagon_sim_proto.h index 2a41536037df..888752623262 100644 --- a/src/runtime/hexagon/sim/hexagon_sim_proto.h +++ b/src/runtime/hexagon/android/sim/hexagon_sim_proto.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_SIM_HEXAGON_SIM_PROTO_H_ -#define TVM_RUNTIME_HEXAGON_SIM_HEXAGON_SIM_PROTO_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ // Protocol: @@ -70,4 +70,4 @@ struct MsgCall { uint32_t data[]; // 12 } __attribute__((packed)); -#endif // TVM_RUNTIME_HEXAGON_SIM_HEXAGON_SIM_PROTO_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_SIM_HEXAGON_SIM_PROTO_H_ diff --git a/src/runtime/hexagon/target/fastrpc/CMakeLists.txt b/src/runtime/hexagon/android/target/fastrpc/CMakeLists.txt similarity index 100% rename from src/runtime/hexagon/target/fastrpc/CMakeLists.txt rename to src/runtime/hexagon/android/target/fastrpc/CMakeLists.txt diff --git a/src/runtime/hexagon/target/fastrpc/README.md b/src/runtime/hexagon/android/target/fastrpc/README.md similarity index 100% rename from src/runtime/hexagon/target/fastrpc/README.md rename to src/runtime/hexagon/android/target/fastrpc/README.md diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl b/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote.idl similarity index 100% rename from src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl rename to src/runtime/hexagon/android/target/fastrpc/include/tvm_remote.idl diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl b/src/runtime/hexagon/android/target/fastrpc/include/tvm_remote_nd.idl similarity index 100% rename from src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl rename to src/runtime/hexagon/android/target/fastrpc/include/tvm_remote_nd.idl diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.cc similarity index 100% rename from src/runtime/hexagon/target/fastrpc/src/tvm_hvx.cc rename to src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.cc diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h b/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h similarity index 95% rename from src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h rename to src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h index 2fe947574bbb..3d14252ad648 100644 --- a/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h +++ b/src/runtime/hexagon/android/target/fastrpc/src/tvm_hvx.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ // Utility providing functions for accessing the Hexagon Vector Extensions // (HVX) hardware. @@ -150,4 +150,4 @@ int cleanup_mt_job(const config_t* hvx_config); } // namespace hvx -#endif // TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_FASTRPC_SRC_TVM_HVX_H_ diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_imp.cc similarity index 100% rename from src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc rename to src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_imp.cc diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_nd_imp.cc similarity index 100% rename from src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc rename to src/runtime/hexagon/android/target/fastrpc/src/tvm_remote_nd_imp.cc diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc b/src/runtime/hexagon/android/target/fastrpc/src/tvm_wrap_pthread.cc similarity index 100% rename from src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc rename to src/runtime/hexagon/android/target/fastrpc/src/tvm_wrap_pthread.cc diff --git a/src/runtime/hexagon/target/hexagon_device_target.cc b/src/runtime/hexagon/android/target/hexagon_device_target.cc similarity index 99% rename from src/runtime/hexagon/target/hexagon_device_target.cc rename to src/runtime/hexagon/android/target/hexagon_device_target.cc index ee326ca0b159..a542c5a3e3a2 100644 --- a/src/runtime/hexagon/target/hexagon_device_target.cc +++ b/src/runtime/hexagon/android/target/hexagon_device_target.cc @@ -27,7 +27,7 @@ #include #include -#include "../hexagon_module.h" +#include "../hexagon_device.h" #include "AEEStdErr.h" #include "fastrpc/include/tvm_remote.h" #include "hexagon_dsprpcapi.h" diff --git a/src/runtime/hexagon/target/hexagon_dsprpcapi.cc b/src/runtime/hexagon/android/target/hexagon_dsprpcapi.cc similarity index 100% rename from src/runtime/hexagon/target/hexagon_dsprpcapi.cc rename to src/runtime/hexagon/android/target/hexagon_dsprpcapi.cc diff --git a/src/runtime/hexagon/target/hexagon_dsprpcapi.h b/src/runtime/hexagon/android/target/hexagon_dsprpcapi.h similarity index 96% rename from src/runtime/hexagon/target/hexagon_dsprpcapi.h rename to src/runtime/hexagon/android/target/hexagon_dsprpcapi.h index e4711e3da584..a3d186e302e3 100644 --- a/src/runtime/hexagon/target/hexagon_dsprpcapi.h +++ b/src/runtime/hexagon/android/target/hexagon_dsprpcapi.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_DSPRPCAPI_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_DSPRPCAPI_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ #ifdef __ANDROID__ #include @@ -189,4 +189,4 @@ class DspRpcAPI { } // namespace tvm #endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_DSPRPCAPI_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_DSPRPCAPI_H_ diff --git a/src/runtime/hexagon/target/hexagon_stubapi.cc b/src/runtime/hexagon/android/target/hexagon_stubapi.cc similarity index 100% rename from src/runtime/hexagon/target/hexagon_stubapi.cc rename to src/runtime/hexagon/android/target/hexagon_stubapi.cc diff --git a/src/runtime/hexagon/target/hexagon_stubapi.h b/src/runtime/hexagon/android/target/hexagon_stubapi.h similarity index 98% rename from src/runtime/hexagon/target/hexagon_stubapi.h rename to src/runtime/hexagon/android/target/hexagon_stubapi.h index fba22b10247c..feb329f5cef2 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.h +++ b/src/runtime/hexagon/android/target/hexagon_stubapi.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_STUBAPI_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_STUBAPI_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ #ifdef __ANDROID__ #include @@ -312,4 +312,4 @@ class StubAPI { } // namespace tvm #endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_STUBAPI_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_STUBAPI_H_ diff --git a/src/runtime/hexagon/target/hexagon_target_log.h b/src/runtime/hexagon/android/target/hexagon_target_log.h similarity index 87% rename from src/runtime/hexagon/target/hexagon_target_log.h rename to src/runtime/hexagon/android/target/hexagon_target_log.h index c7684fc56197..f8ba6a74e3b9 100644 --- a/src/runtime/hexagon/target/hexagon_target_log.h +++ b/src/runtime/hexagon/android/target/hexagon_target_log.h @@ -17,8 +17,8 @@ * under the License. */ -#ifndef TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ +#ifndef TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ +#define TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ #ifdef __ANDROID__ #include @@ -31,4 +31,4 @@ #define TVM_LOGF(...) __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) #endif // __ANDROID__ -#endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ +#endif // TVM_RUNTIME_HEXAGON_ANDROID_TARGET_HEXAGON_TARGET_LOG_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon/hexagon_buffer.cc new file mode 100644 index 000000000000..0760bab6c582 --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_buffer.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "hexagon_buffer.h" + +#include + +#include +#include + +#include "hexagon_common.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +static size_t GetDataAlignment(const DLDataType dtype) { + size_t align = (dtype.bits / 8) * dtype.lanes; + if (align < kAllocAlignment) return kAllocAlignment; + return align; +} + +HexagonBuffer::HexagonBuffer(int ndim, const int64_t* shape, DLDataType dtype, + Optional scope) { + ICHECK_LE(ndim, 1) << "Hexagon currently only supports flat allocations " + << "and arrays of flat allocations."; + + size_t alignment = GetDataAlignment(dtype); + // TODO(csullivan): Extend to support arrays of allocations. + // Move assignment from r-value constructed flat allocation. + *this = HexagonBuffer(shape[0] * (dtype.bits / 8) * dtype.lanes, alignment, scope); +} + +HexagonBuffer::HexagonBuffer(size_t nbytes, size_t alignment, Optional scope) { + void* ptr = nullptr; + int ret = posix_memalign(&ptr, alignment, nbytes); + if (ret != 0) { + throw std::bad_alloc(); + } + allocations_.push_back(ptr); + SetStorageScope(scope); +} + +HexagonBuffer::HexagonBuffer(void* data, Optional scope) : managed_{false} { + SetStorageScope(scope); + allocations_.push_back(data); +} + +HexagonBuffer::~HexagonBuffer() { + if (managed_) { + for (auto& ptr : allocations_) { + free(ptr); + } + } +} + +HexagonBuffer::HexagonBuffer(HexagonBuffer&& other) + : allocations_(other.allocations_), + managed_(other.managed_), + storage_scope_(other.storage_scope_) { + other.allocations_.clear(); + other.managed_ = false; + other.storage_scope_ = StorageScope::kDDR; +} + +HexagonBuffer& HexagonBuffer::operator=(HexagonBuffer&& other) { + std::swap(allocations_, other.allocations_); + std::swap(managed_, other.managed_); + std::swap(storage_scope_, other.storage_scope_); + return *this; +} + +void* HexagonBuffer::GetPointer() { + if (!allocations_.size()) { + return nullptr; + } + return (allocations_.size() > 1) ? allocations_.data() : allocations_[0]; +} + +HexagonBuffer::StorageScope HexagonBuffer::GetStorageScope() const { return storage_scope_; } + +void HexagonBuffer::SetStorageScope(Optional scope) { + if (!scope.defined()) { + storage_scope_ = StorageScope::kDDR; + } else { + if (scope.value() == "global") { + storage_scope_ = StorageScope::kDDR; + } else if (scope.value() == "global.vtcm") { + storage_scope_ = StorageScope::kVTCM; + } else { + CHECK(false) << "Encountered unknown HexagonBuffer storage scope: " + << std::string(scope.value()); + } + } +} + +HexagonBuffer* IsHexagonBuffer(DLTensor* tensor) { + if (TVMDeviceExtType(tensor->device.device_type) == kDLHexagon) { + return static_cast(tensor->data); + } + return nullptr; +} + +} // namespace hexagon +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/hexagon/hexagon/hexagon_buffer.h b/src/runtime/hexagon/hexagon/hexagon_buffer.h new file mode 100644 index 000000000000..c62cee66b0d8 --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_buffer.h @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ + +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace runtime { +namespace hexagon { + +class HexagonBuffer { + public: + /* \brief Allocate memory within hexagon accessible memory + * scopes. + * + * \param ndim The number of dimensions of physical storage + * to allocate. + * + * \param shape The shape of the ndarray for which to allocate + * physical storage. + * + * \param dtype The data type of the physical storage. + * + * \param scope Optional storage scope indicating the memory + * space in which to allocate. Defaults to global system + * memory (DDR). + */ + HexagonBuffer(int ndim, const int64_t* shape, DLDataType dtype, Optional scope); + + /* \brief Allocate memory within hexagon accessible memory + * scopes. + * + * \param nbytes The number of bytes of flat physical storage + * to allocate. + * + * \param alignment The byte alignment to be used when allocating. + * + * \param scope Optional storage scope indicating the memory + * space in which to allocate. Defaults to global system + * memory (DDR). + */ + HexagonBuffer(size_t nbytes, size_t alignment, Optional scope); + + /* \brief Construct a hexagon buffer from externally allocated storage. + * + * \param data The externally allocated storage. + * + * \param scope Optional storage scope indicating the memory + * space in the external allocation belongs. Assumes global system + * memory if not provided. + */ + explicit HexagonBuffer(void* data, Optional scope = Optional()); + + //! \brief Destruction deallocates the underlying allocations. + ~HexagonBuffer(); + + //! \brief Prevent copy construction of HexagonBuffers. + HexagonBuffer(const HexagonBuffer&) = delete; + + //! \brief Prevent copy assignment with HexagonBuffers. + HexagonBuffer& operator=(const HexagonBuffer&) = delete; + + //! \brief Allow move construction. + HexagonBuffer(HexagonBuffer&&); + + //! \brief Allow move assignment. + HexagonBuffer& operator=(HexagonBuffer&&); + + //! \brief Return pointer to allocation or allocations. + void* GetPointer(); + + //! \brief Memory scopes managed by a Hexagon Buffer. + enum class StorageScope { + //! \brief System DDR corresponding to global storage. + kDDR, + /*! \brief Vector tightly coupled memory corresponding to + * global.vtcm storage. + */ + kVTCM, + }; + + //! \brief Return storage scope of underlying allocation. + StorageScope GetStorageScope() const; + + private: + //! \brief Assign a storage scope to the buffer. + void SetStorageScope(Optional scope); + /*! \brief Array of allocations required by the buffer. + * + * For a 1d (flat) storage, a single contiguous allocation will + * result. For 2d storage, (count, nbytes) = shape, which will + * result in `count` discrete allocations. + */ + std::vector allocations_; + /*! \brief Whether the allocation(s) present are managed + * and should be deallocated upon destruction. + */ + bool managed_{true}; + /*! \brief The underlying storage type in which the allocation + * resides. + */ + StorageScope storage_scope_; +}; + +HexagonBuffer* IsHexagonBuffer(DLTensor* tensor); + +} // namespace hexagon +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_BUFFER_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon/hexagon_common.cc new file mode 100644 index 000000000000..260b105ac43a --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_common.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hexagon_common.cc + */ + +#include "hexagon_common.h" + +#include +#include + +#include +#include +#include +#include + +#include "hexagon_buffer.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +void HexagonLookupLinkedParam(TVMArgs args, TVMRetValue* rv) { + Module mod = args[0]; + int64_t storage_id = args[1]; + DLTensor* template_tensor = args[2]; + Device dev = args[3]; + auto lookup_linked_param = mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true); + if (lookup_linked_param == nullptr) { + *rv = nullptr; + return; + } + + TVMRetValue opaque_handle = lookup_linked_param(storage_id); + if (opaque_handle.type_code() == kTVMNullptr) { + *rv = nullptr; + return; + } + + std::vector shape_vec{template_tensor->shape, + template_tensor->shape + template_tensor->ndim}; + + auto* param_buffer = new HexagonBuffer(static_cast(opaque_handle)); + auto* container = new NDArray::Container(static_cast(param_buffer), shape_vec, + template_tensor->dtype, dev); + container->SetDeleter([](Object* container) { + // The NDArray::Container needs to be deleted + // along with the HexagonBuffer wrapper. However the + // buffer's data points to global const memory and + // so should not be deleted. + auto* ptr = static_cast(container); + delete static_cast(ptr->dl_tensor.data); + delete ptr; + }); + *rv = NDArray(GetObjectPtr(container)); +} + +PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { + return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + TVMValue ret_value; + int ret_type_code = kTVMNullptr; + + TVMValue* arg_values = const_cast(args.values); + std::vector> buffer_args; + for (size_t i = 0; i < args.num_args; i++) { + if (args.type_codes[i] == kTVMDLTensorHandle) { + DLTensor* tensor = static_cast(arg_values[i].v_handle); + buffer_args.emplace_back(i, static_cast(tensor->data)); + tensor->data = buffer_args.back().second->GetPointer(); + } + } + int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), + args.num_args, &ret_value, &ret_type_code, nullptr); + ICHECK_EQ(ret, 0) << TVMGetLastError(); + + for (auto& arg : buffer_args) { + DLTensor* tensor = static_cast(arg_values[arg.first].v_handle); + tensor->data = arg.second; + } + + if (ret_type_code != kTVMNullptr) { + *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); + } + }); +} +} // namespace hexagon + +namespace { +std::vector SplitString(const std::string& str, char delim) { + std::vector lines; + auto ss = std::stringstream{str}; + for (std::string line; std::getline(ss, line, delim);) { + lines.push_back(line); + } + return lines; +} +void HexagonLog(const std::string& file, int lineno, const std::string& message) { + HEXAGON_PRINT(ALWAYS, "%s:%d:", file.c_str(), lineno); + std::vector err_lines = SplitString(message, '\n'); + for (auto& line : err_lines) { + HEXAGON_PRINT(ALWAYS, "%s", line.c_str()); + } +} +} // namespace + +namespace detail { +void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { + HexagonLog(file, lineno, message); + throw InternalError(file, lineno, message); +} +void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { + HexagonLog(file, lineno, message); +} +} // namespace detail + +TVM_REGISTER_GLOBAL("tvm.runtime.hexagon.lookup_linked_params") + .set_body(hexagon::HexagonLookupLinkedParam); +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/hexagon/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon/hexagon_common.h new file mode 100644 index 000000000000..87d36c9865e8 --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_common.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hexagon_utils.h + */ +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ + +#include +#include +#include +#include + +#if defined(__hexagon__) +#include +#define HEXAGON_PRINT(level, ...) FARF(level, __VA_ARGS__) +#else +#include +#define HEXAGON_PRINT(level, ...) printf(__VA_ARGS__) +#endif + +#define HEXAGON_SAFE_CALL(api_call) \ + do { \ + int result = api_call; \ + if (result != 0) { \ + HEXAGON_PRINT(ERROR, "ERROR: " #api_call " failed with error %d.", result); \ + } \ + } while (0) + +namespace tvm { +namespace runtime { +namespace hexagon { + +/*! \brief Unpack HexagonBuffers in packed functions + * prior to invoking. + * \param faddr The function address. + * \param mptr The module pointer node. + * \return A packed function wrapping the requested function. + */ +PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr); +} // namespace hexagon +} // namespace runtime +} // namespace tvm +inline bool IsHexagonDevice(DLDevice dev) { + return TVMDeviceExtType(dev.device_type) == kDLHexagon; +} + +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_COMMON_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc new file mode 100644 index 000000000000..9c1f6ebd7d70 --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hexagon_device_api_v2.cc + */ + +#include "hexagon_device_api_v2.h" + +#include +#include +#include +#include + +#include +#include + +#include "../../workspace_pool.h" +#include "hexagon_buffer.h" +#include "hexagon_common.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +HexagonDeviceAPIv2* HexagonDeviceAPIv2::Global() { + static auto* inst = new HexagonDeviceAPIv2(); + return inst; +} + +void HexagonDeviceAPIv2::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { + if (kind == kExist) { + *rv = 1; + } +} + +void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* shape, + DLDataType dtype, Optional mem_scope) { + return new HexagonBuffer(ndim, shape, dtype, mem_scope.defined() ? mem_scope : String("global")); +} + +void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, + DLDataType type_hint) { + return new HexagonBuffer(nbytes, alignment, String("global")); +} + +void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) { + auto* pbuf = static_cast(ptr); + delete pbuf; +} + +struct HexagonWorkspacePool : public WorkspacePool { + HexagonWorkspacePool() : WorkspacePool(kDLCPU, HexagonDeviceAPIv2::Global()) {} +}; + +void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { + auto* buffer = static_cast( + dmlc::ThreadLocalStore::Get()->AllocWorkspace(dev, size)); + void* ptr = buffer->GetPointer(); + workspace_allocations_.insert({ptr, buffer}); + return ptr; +} + +void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) { + auto it = workspace_allocations_.find(data); + ICHECK(it != workspace_allocations_.end()) + << "Attempt made to free unknown or already freed workspace allocation"; + dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, it->second); + workspace_allocations_.erase(it); +} + +void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + if (IsHexagonDevice(from->device) && IsHexagonDevice(to->device)) { + HexagonBuffer* buffer_src = static_cast(from->data); + HexagonBuffer* buffer_dst = static_cast(to->data); + // Check storage scopes + if (buffer_src->GetStorageScope() == HexagonBuffer::StorageScope::kDDR && + buffer_dst->GetStorageScope() == HexagonBuffer::StorageScope::kDDR) { + memcpy(static_cast(buffer_dst->GetPointer()) + to->byte_offset, + static_cast(buffer_src->GetPointer()) + from->byte_offset, + GetDataSize(*from)); + } else { + ICHECK(false) << "Currently only copying between DDR storage is supported."; + } + } else if (IsHexagonDevice(from->device) && to->device.device_type == kDLCPU) { + HexagonBuffer* buffer_src = static_cast(from->data); + memcpy(static_cast(to->data) + to->byte_offset, + static_cast(buffer_src->GetPointer()) + from->byte_offset, + GetDataSize(*from)); + } else if (from->device.device_type == kDLCPU && IsHexagonDevice(to->device)) { + HexagonBuffer* buffer_dst = static_cast(to->data); + memcpy(static_cast(buffer_dst->GetPointer()) + to->byte_offset, + static_cast(from->data) + from->byte_offset, GetDataSize(*from)); + } else { + CHECK(false) + << "Expect copy between DLTensor devices of types kDLHexagon and kDLCPU (external) only."; + } +} + +void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, Device dev_from, + Device dev_to, DLDataType type_hint, + TVMStreamHandle stream) { + memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); +} + +TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = HexagonDeviceAPIv2::Global(); + *rv = static_cast(ptr); +}); + +} // namespace hexagon +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h new file mode 100644 index 000000000000..3d866307f17c --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ + +#include + +#include + +namespace tvm { +namespace runtime { +namespace hexagon { + +class HexagonBuffer; + +/*! + * \brief Hexagon Device API that is compiled and run on Hexagon. + */ +class HexagonDeviceAPIv2 final : public DeviceAPI { + public: + //! \brief Retrieve the global singleton instance of the HexagonDeviceAPIv2. + static HexagonDeviceAPIv2* Global(); + + //! \brief Constructor + HexagonDeviceAPIv2() {} + + //! \brief Destructor + ~HexagonDeviceAPIv2() {} + + /*! \brief Currently unimplemented interface to specify the active + * Hexagon device. + */ + void SetDevice(Device dev) final{}; + + //! \brief Return the queried Hexagon device attribute. + void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final; + + //! \brief Currently unimplemented interface to synchronize a device stream. + void StreamSync(Device dev, TVMStreamHandle stream) final {} + + //! \note Standard memory allocation methods of the DeviceAPI interface. + //! \brief Allocate a flat allocation of global memory wrapped in a HexagonBuffer. + void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; + + //! \brief Free the allocated HexagonBuffer. + void FreeDataSpace(Device dev, void* ptr) final; + + /*! \brief Request a dynamically allocated HexagonBuffer from a workspace pool. + * \returns The underlying allocation pointer. + */ + void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; + + //! Dereference workspace pool and erase from tracked workspace_allocations_. + void FreeWorkspace(Device dev, void* data) final; + + /*! + * \brief Allocate an Nd data space on device with memory scope support. + * \param dev The device to perform the operation. + * \param ndim The number of dimensions of allocated tensor. + * \param shape The shape of allocated tensor. + * \param dtype The element type. + * \param mem_scope The memory scope of the allocated tensor. + * \return The allocated HexagonBuffer pointer. + */ + void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + Optional mem_scope) final; + + /*! + * \brief Copy data from one storage to another. + * \note This API is designed to support special memory with shape dependent layout. + * DLTensor's are passed with shape information to support these cases. + * \param from The source array. + * \param to The target array. + * \param stream Optional stream object. + */ + void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final; + + protected: + //! Standard Device API interface to copy data from one storage to another. + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + Device dev_from, Device dev_to, DLDataType type_hint, + TVMStreamHandle stream) final; + + private: + //! Lookup table for the HexagonBuffer managing a workspace allocation. + std::unordered_map workspace_allocations_; +}; +} // namespace hexagon +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_HEXAGON_DEVICE_API_V2_H_ diff --git a/src/runtime/hexagon/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon/hexagon_module.cc new file mode 100644 index 000000000000..a4919ce874e2 --- /dev/null +++ b/src/runtime/hexagon/hexagon/hexagon_module.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file hexagon_module.cc + * \brief The HexagonLibraryModuleNode + */ +#include "../hexagon_module.h" + +#include +#include +#include + +#include +#include +#include + +#include "../../library_module.h" +#include "hexagon_buffer.h" +#include "hexagon_common.h" + +namespace tvm { +namespace runtime { + +Module HexagonModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, + const std::set& packed_c_abi) { + CHECK(fmt == "so") << "Invalid format provided when constructing Hexagon runtime module: " << fmt + << ". Valid formats are: 'so'."; + ObjectPtr n = CreateDSOLibraryObject(data); + return CreateModuleFromLibrary(n, hexagon::WrapPackedFunc); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = CreateDSOLibraryObject(args[0]); + *rv = CreateModuleFromLibrary(n, hexagon::WrapPackedFunc); +}); +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index 1288b933410c..887d9bb30ecb 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -50,107 +50,6 @@ Module HexagonModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string asm_str, std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi); - -namespace hexagon { - -/*! - * \brief Low-level interface for communicating with Hexagon devices. - */ -class Device { - public: - /*! - * \brief Allocate memory on device. - * \param size Requested size. - * \param align Requested alignment. - * \return Pointer (local to the device) of the allocated memory, - * or nullptr if allocation failed. - */ - virtual void* Alloc(unsigned size, unsigned align) = 0; - /*! - * \brief Release allocated memory on device. - * \param ptr Pointer to memory previously allocated by \ref Alloc. - */ - virtual void Free(void* ptr) = 0; - /*! - * \brief Allocate VTCM memory on device. - * \param size Requested size. - * \param align Requested alignment. - * \return Pointer (local to the device) of the allocated memory, - * or nullptr if allocation failed. - */ - virtual void* AllocVtcm(unsigned size, unsigned align) = 0; - /*! - * \brief Release allocated VTCM memory on device. - * \param ptr Pointer to memory previously allocated by \ref AllocVtcm. - */ - virtual void FreeVtcm(void* ptr) = 0; - /*! - * \brief Copy a block of data on device to another location on the device. - * \param dst Pointer (local to device) to the destination buffer. - * \param src Pointer (local to device) of the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; - /*! - * \brief Copy a block of data from device to host. - * \param host_dst Pointer (local to host) to the destination buffer. - * \param src Pointer (local to device) to the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; - /*! - * \brief Copy a block of data from host to device. - * \param dst Pointer (local to device) to the destination buffer. - * \param host_src Pointer (local to host) to the source buffer. - * \param len Number of bytes to copy. - */ - virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; - /*! - * \brief Load a module (typically a shared library) into device. - * \param data Name of the shared library. - * \param fmt Format of the library (currently ignored). - * \return Pointer to the loaded module. - * \note Currently only one module can be loaded at any given time. - */ - virtual void* Load(const std::string& data, const std::string& fmt) = 0; - /*! - * \brief Unload a module from device. - * \param mod Pointer to a loaded module returned by \ref Load. - */ - virtual void Unload(void* mod) = 0; - /*! - * \brief Find the address of an object in the currently loaded module. - * \param sym Name of the object. - * \return Address of the located object, or nullptr if object was - * not found. - */ - virtual void* Resolve(const std::string& sym) = 0; - /*! - * \brief Invoke a function on device with given arguments. - * \param func Address (local to device) of the function to call. - * \param scalar Pointer to an array of 32-bit values that will be - * passed via consecutive registers: r0..r5. This array - * includes dummy values for skipped registers. - * \param sc_num Number of values in the "scalar" array. - * \param stack Pointer to an array of 32-bit values that will be - * passed on the stack. This array includes dummy values - * for padding. - * \param st_num Number of values in the "stack" array. - */ - virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) = 0; - - virtual ~Device() = 0; - - static std::shared_ptr Global(); - static bool ValidateDeviceId(decltype(DLDevice::device_id) device_id) { - // Only supporting a single device for now. - return device_id == 0; - } -}; - -} // namespace hexagon - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 5dfd5e8ad7d5..7efa91d912eb 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -37,7 +37,8 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} + explicit LibraryModuleNode(ObjectPtr lib, PackedFuncWrapper wrapper) + : lib_(lib), packed_func_wrapper_(wrapper) {} const char* type_key() const final { return "library"; } @@ -53,11 +54,12 @@ class LibraryModuleNode final : public ModuleNode { faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); } if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + return packed_func_wrapper_(faddr, sptr_to_self); } private: ObjectPtr lib_; + PackedFuncWrapper packed_func_wrapper_; }; /*! @@ -128,7 +130,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) { * \param root_module the output root module * \param dso_ctx_addr the output dso module */ -void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Module* root_module, +void ProcessModuleBlob(const char* mblob, ObjectPtr lib, + PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module, runtime::ModuleNode** dso_ctx_addr = nullptr) { ICHECK(mblob != nullptr); uint64_t nbytes = 0; @@ -152,7 +155,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // "_lib" serves as a placeholder in the module import tree to indicate where // to place the DSOModule if (tkey == "_lib") { - auto dso_module = Module(make_object(lib)); + auto dso_module = Module(make_object(lib, packed_func_wrapper)); *dso_ctx_addr = dso_module.operator->(); ++num_dso_module; modules.emplace_back(dso_module); @@ -170,7 +173,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul // if we are using old dll, we don't have import tree // so that we can't reconstruct module relationship using import tree if (import_tree_row_ptr.empty()) { - auto n = make_object(lib); + auto n = make_object(lib, packed_func_wrapper); auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); for (const auto& m : modules) { module_import_addr->emplace_back(m); @@ -194,9 +197,9 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr lib, runtime::Modul } } -Module CreateModuleFromLibrary(ObjectPtr lib) { +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper packed_func_wrapper) { InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); - auto n = make_object(lib); + auto n = make_object(lib, packed_func_wrapper); // Load the imported modules const char* dev_mblob = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); @@ -204,7 +207,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { Module root_mod; runtime::ModuleNode* dso_ctx_addr = nullptr; if (dev_mblob != nullptr) { - ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr); + ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr); } else { // Only have one single DSO Module root_mod = Module(n); @@ -218,5 +221,10 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { return root_mod; } + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = CreateDSOLibraryObject(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 00c79e8248f4..b5780975f43a 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -78,16 +78,35 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& */ void InitContextFunctions(std::function fgetsymbol); +/*! + * \brief Type alias for funcion to wrap a TVMBackendPackedCFunc. + * \param The function address imported from a module. + * \param mptr The module pointer node. + * \return Packed function that wraps the invocation of the function at faddr. + */ +using PackedFuncWrapper = + std::function& mptr)>; + +/*! \brief Return a library object interface over dynamic shared + * libraries in Windows and Linux providing support for + * loading/unloading and symbol lookup. + * \param Full path to shared library. + * \return Returns pointer to the Library providing symbol lookup. + */ +ObjectPtr CreateDSOLibraryObject(std::string library_path); + /*! * \brief Create a module from a library. * * \param lib The library. + * \param wrapper Optional function used to wrap a TVMBackendPackedCFunc, + * by default WrapPackedFunc is used. * \return The corresponding loaded module. * * \note This function can create multiple linked modules * by parsing the binary blob section of the library. */ -Module CreateModuleFromLibrary(ObjectPtr lib); +Module CreateModuleFromLibrary(ObjectPtr lib, PackedFuncWrapper wrapper = WrapPackedFunc); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/micro/crt_config.h b/src/runtime/micro/crt_config.h index c3e8fea1ba08..602060de1b4a 100644 --- a/src/runtime/micro/crt_config.h +++ b/src/runtime/micro/crt_config.h @@ -37,7 +37,9 @@ /*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ #define TVM_CRT_MAX_STRLEN_DLTYPE 10 /*! Maximum supported string length in function names */ -#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 120 +/*! Maximum supported string length in parameter names */ +#define TVM_CRT_MAX_STRLEN_PARAM_NAME 80 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 3cd5df613f4a..4e24434642d8 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -41,7 +41,7 @@ namespace runtime { struct TypeInfo { /*! \brief The current index. */ uint32_t index{0}; - /*! \brief Index of the parent in the type hierachy */ + /*! \brief Index of the parent in the type hierarchy */ uint32_t parent_index{0}; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. @@ -58,7 +58,7 @@ struct TypeInfo { }; /*! - * \brief Type context that manages the type hierachy information. + * \brief Type context that manages the type hierarchy information. */ class TypeContext { public: diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 26eddb40a7d5..f12a143ab0cc 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -196,6 +196,10 @@ void* OpenCLWorkspace::AllocDataSpace(Device dev, size_t size, size_t alignment, ICHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; cl::BufferDescriptor* desc = new cl::BufferDescriptor; + // CL_INVALID_BUFFER_SIZE if size is 0. + if (size == 0) { + size = 1; + } desc->buffer = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); desc->layout = cl::BufferDescriptor::MemoryLayout::kBuffer1D; OPENCL_CHECK_ERROR(err_code); diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index a1d06fc8cab8..90d4ac64238f 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -30,6 +30,7 @@ #include #include +#include #include #include #include @@ -342,7 +343,7 @@ String ReportNode::AsJSON() const { return s.str(); } -String ReportNode::AsTable(bool sort, bool aggregate) const { +String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) const { // aggregate calls by op hash (or op name if hash is not set) + argument shapes std::vector> aggregated_calls; if (aggregate) { @@ -414,36 +415,38 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { } // compute columnwise sums - std::unordered_map col_sums; - for (auto call : aggregated_calls) { - for (auto p : call) { - if (p.second.as()) { - int64_t val = p.second.as()->value; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->value; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->microseconds; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->microseconds; - } - col_sums[p.first] = ObjectRef(make_object(val)); - } else if (p.second.as()) { - double val = p.second.as()->percent; - auto it = col_sums.find(p.first); - if (it != col_sums.end()) { - val += it->second.as()->percent; + if (compute_col_sums) { + std::unordered_map col_sums; + for (auto call : aggregated_calls) { + for (auto p : call) { + if (p.second.as()) { + int64_t val = p.second.as()->value; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->value; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->microseconds; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->microseconds; + } + col_sums[p.first] = ObjectRef(make_object(val)); + } else if (p.second.as()) { + double val = p.second.as()->percent; + auto it = col_sums.find(p.first); + if (it != col_sums.end()) { + val += it->second.as()->percent; + } + col_sums[p.first] = ObjectRef(make_object(val)); } - col_sums[p.first] = ObjectRef(make_object(val)); } } + col_sums["Name"] = String("Sum"); + aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator + aggregated_calls.push_back(col_sums); } - col_sums["Name"] = String("Sum"); - aggregated_calls.push_back({{String("Name"), String("----------")}}); // separator - aggregated_calls.push_back(col_sums); // per-device metrics for (auto p : device_metrics) { @@ -454,7 +457,6 @@ String ReportNode::AsTable(bool sort, bool aggregate) const { // Table formatting std::set unique_headers; - for (auto row : aggregated_calls) { for (auto p : row) { unique_headers.insert(p.first); @@ -666,6 +668,7 @@ TVM_REGISTER_OBJECT_TYPE(ReportNode); TVM_REGISTER_OBJECT_TYPE(DeviceWrapperNode); TVM_REGISTER_OBJECT_TYPE(MetricCollectorNode); +TVM_REGISTER_GLOBAL("runtime.profiling.AsTable").set_body_method(&ReportNode::AsTable); TVM_REGISTER_GLOBAL("runtime.profiling.AsCSV").set_body_typed([](Report n) { return n->AsCSV(); }); TVM_REGISTER_GLOBAL("runtime.profiling.AsJSON").set_body_typed([](Report n) { return n->AsJSON(); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index a5e7d253f3cd..4d7ee457e1e6 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -61,6 +61,8 @@ PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtrGetBytecode(); }); + } else if (name == "get_constants") { + return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetConstants(); }); } else if (name == "get_stats") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { @@ -146,6 +148,34 @@ std::string Executable::GetBytecode() const { return oss.str(); } +namespace { +String ShapeString(const ShapeTuple& shape_tuple, DLDataType dtype) { + std::stringstream sizes; + sizes << DLDataType2String(dtype) << "["; + for (size_t i = 0; i < shape_tuple.size(); i++) { + if (i != 0) { + sizes << ", "; + } + sizes << shape_tuple.data()[i]; + } + sizes << "]"; + return String(sizes.str()); +} +} // namespace + +std::string Executable::GetConstants() const { + std::ostringstream oss; + + for (size_t i = 0; i < constants.size(); ++i) { + const auto& constant = constants[i]; + auto ndarray = Downcast(constant); + DLDeviceType device_type = static_cast(const_device_type[i]); + oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) + << " on device of type " << device_type << std::endl; + } + return oss.str(); +} + std::string Executable::Stats() const { std::ostringstream oss; oss << "Relay VM executable statistics:" << std::endl; @@ -308,7 +338,7 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { VMInstructionSerializer SerializeInstruction(const Instruction& instr) { std::vector fields; // Save the opcode. - DLOG(INFO) << "Serializing: " << instr << std::endl; + VLOG(1) << "Serializing: " << instr << std::endl; switch (instr.op) { case Opcode::Move: { // Number of fields = 2 diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 410f6c2a042d..22afcce6a01e 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -119,14 +119,14 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) { std::unique_ptr alloc; switch (type) { case kNaive: { - DLOG(INFO) << "New naive allocator for " << DeviceName(dev.device_type) << "(" - << dev.device_id << ")"; + VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id + << ")"; alloc.reset(new NaiveAllocator(dev)); break; } case kPooled: { - DLOG(INFO) << "New pooled allocator for " << DeviceName(dev.device_type) << "(" - << dev.device_id << ")"; + VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "(" + << dev.device_id << ")"; alloc.reset(new PooledAllocator(dev)); break; } diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index c282eb006f92..e5f236983a73 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -67,7 +67,7 @@ class PooledAllocator final : public Allocator { } used_memory_.fetch_add(size, std::memory_order_relaxed); - DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B"; return buf; } @@ -77,7 +77,7 @@ class PooledAllocator final : public Allocator { memory_pool_.emplace(buffer.size, std::vector{}); } memory_pool_.at(buffer.size).push_back(buffer); - DLOG(INFO) << "reclaim buffer " << buffer.size; + VLOG(1) << "reclaim buffer " << buffer.size; } size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } @@ -93,7 +93,7 @@ class PooledAllocator final : public Allocator { } memory_pool_.clear(); used_memory_ = 0; - DLOG(INFO) << "release all buffers"; + VLOG(1) << "release all buffers"; } private: diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h index cbcdb1bdfa16..b4a10806caaf 100644 --- a/src/runtime/vm/serialize_utils.h +++ b/src/runtime/vm/serialize_utils.h @@ -59,13 +59,13 @@ struct VMFunctionSerializer { /*! \brief The parameters of the VMFunction. */ std::vector params; /*! \brief The device type of each parameter of the VMFunction. */ - std::vector params_device_type; + std::vector params_device_type; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params, - const std::vector& params_device_type) + const std::vector& params_device_type) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index addd5ca5d861..b903f793d799 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -236,7 +236,7 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) { << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = offset; i < args.size(); ++i) { - DLDeviceType device_type = vm_func.params_device_type[i - offset]; + Index device_type = vm_func.params_device_type[i - offset]; Device dev = GetDevice(device_type); if (args[i].type_code() == kTVMDLTensorHandle) { @@ -284,20 +284,20 @@ Index VirtualMachine::PopFrame() { } void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { - DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); + VLOG(2) << "Invoking global " << func.name << " " << args.size(); PushFrame(func.params.size(), this->pc_ + 1, func); for (size_t i = 0; i < args.size(); ++i) { WriteRegister(i, args[i]); } - DLOG(INFO) << "func.params= " << func.params.size(); + VLOG(2) << "func.params= " << func.params.size(); code_ = func.instructions.data(); pc_ = 0; } ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { - DLOG(INFO) << "Executing Function: " << std::endl << func; + VLOG(2) << "Executing Function: " << std::endl << func; InvokeGlobal(func, args); RunLoop(); @@ -309,7 +309,7 @@ ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vectorglobal_map.find(name); ICHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable"; auto func_index_ = it->second; - DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_; + VLOG(2) << "Invoke Global " << name << " at index " << func_index_; return Invoke(exec_->functions[func_index_], args); } @@ -445,7 +445,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = code_[this->pc_]; - DLOG(INFO) << "Executing(" << pc_ << "): " << instr; + VLOG(2) << "Executing(" << pc_ << "): " << instr; switch (instr.op) { case Opcode::Move: { @@ -500,13 +500,13 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::InvokePacked: { - DLOG(INFO) << "InvokedPacked " << instr.packed_index << " arity=" << instr.arity; + VLOG(2) << "InvokedPacked " << instr.packed_index << " arity=" << instr.arity; ICHECK_LE(instr.packed_index, packed_funcs_.size()); const auto& func = packed_funcs_[instr.packed_index]; const auto& arity = instr.arity; std::vector args; for (Index i = 0; i < arity; ++i) { - DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i]; + VLOG(2) << "arg" << i << " $" << instr.packed_args[i]; auto arg = ReadRegister(instr.packed_args[i]); args.push_back(arg); } @@ -579,6 +579,18 @@ void VirtualMachine::RunLoop() { auto storage_obj = ReadRegister(instr.alloc_tensor.storage); auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto storage = Downcast(storage_obj); +#if TVM_LOG_DEBUG + std::ostringstream os; + os << "AllocTensor: "; + os << "offset=" << offset; + os << ", shape=["; + for (auto i : shape) { + os << i << ","; + } + os << "]"; + os << ", dtype=" << DLDataType2String(instr.alloc_tensor.dtype); + VLOG(2) << os.str(); +#endif auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); @@ -625,17 +637,15 @@ void VirtualMachine::RunLoop() { OpStartHook(instr); auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; - - DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment - << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) - << ", device_type=" << instr.alloc_storage.device_type; - auto storage_obj = SimpleObjAllocator().make_object(); auto dev_type = instr.alloc_storage.device_type; ICHECK_LT(static_cast(dev_type), allocators_.size()) << "Memory allocator for device " << dev_type << " has not been initialized"; auto* alloc = allocators_[dev_type]; ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; + VLOG(2) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment + << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) + << ", device_type=" << instr.alloc_storage.device_type; storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); diff --git a/src/support/utils.h b/src/support/utils.h index d8e3bf5f30ab..3bb870022214 100644 --- a/src/support/utils.h +++ b/src/support/utils.h @@ -145,7 +145,7 @@ inline bool StartsWith(const String& str, const char* prefix) { if (str.data()[i] != prefix[i]) return false; } // return true if the str is equal to the prefix - return prefix[n + 1] == '\0'; + return prefix[n] == '\0'; } /*! diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc new file mode 100644 index 000000000000..9797b6751af7 --- /dev/null +++ b/src/target/compilation_config.cc @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/compilation_config.cc + * \brief Implementation of \p CompilationConfig for collecting \p Targets. + */ + +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(CompilationConfigNode); + +void CompilationConfigNode::VisitAttrs(AttrVisitor* v) { + v->Visit("legacy_target_map", &legacy_target_map); + v->Visit("host_target", &host_target); + v->Visit("primitive_targets", &primitive_targets); + v->Visit("default_primitive_se_scope", &default_primitive_se_scope); + v->Visit("host_se_scope", &host_se_scope); + v->Visit("optional_homogenous_target", &optional_homogeneous_target); + // NOTE: The se_scope_cache_ is not accessible via FFI. +} + +SEScope CompilationConfigNode::CanonicalSEScope(const SEScope& se_scope) const { + if (se_scope->target.defined()) { + return se_scope_cache_.Unique(se_scope); + } + DLDeviceType device_type = se_scope->device_type(); + // TODO(mbs): Proper diagnostics. + CHECK(device_type != kInvalidDeviceType) + << "SEScope annotations must include at least a device_type"; + Target target = FindPrimitiveTargetOrFail(se_scope->device_type()); + return se_scope_cache_.Unique( + SEScope(device_type, se_scope->virtual_device_id, target, se_scope->memory_scope)); +} + +void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContext& pass_ctx) { + // + // Gather the hints as to what our default device type for the 'host' should be, and + // create an appropriate target if we don't already have one. + // + DLDeviceType host_device_type; + if (host_target.defined()) { + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + DLOG(INFO) << "Using the given host target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; + for (const auto& primitive_target : primitive_targets) { + if (primitive_target->host.defined() && + !StructuralEqual()(primitive_target->host, host_target)) { + DLOG(WARNING) << "The primitive target " << primitive_target->ToDebugString() + << " already has a host which disagrees with the desired host target. It " + << "will be ignored."; + } + } + } else if (primitive_targets.size() == 1 && primitive_targets.front()->host.defined()) { + host_target = primitive_targets.front()->GetHost().value(); + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + DLOG(INFO) << "Using the host of the unique primitive target, namely " + << host_target->ToDebugString() << " of device type " << host_device_type + << " for the host target"; + } else if (primitive_targets.size() == 1 && + primitive_targets.front()->kind->device_type == kDLCPU) { + // In the homogenous case without an explicit host target just use the given target so long as + // it's a CPU. + host_device_type = kDLCPU; + host_target = primitive_targets.front(); + DLOG(INFO) << "Using the unique primitive target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; + } else { + // Fallback. + host_device_type = kDLCPU; + // Even if the list of available targets already includes one for kDLCPU we won't use it + // in the hetrogeneous case since its options may not be appropriate for host code + // (eg shape functions). Instead, create a fresh default Target. + host_target = MakeDefaultTarget(host_device_type); + DLOG(WARNING) << "Using the default target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; + } + ICHECK(host_target.defined()); + ICHECK(!host_target->host.defined()); + + if (host_device_type != kDLCPU) { + // I think we're on thin ice here until we've audited the code base for assumed kDLCPU. + LOG(WARNING) << "The host target is not a CPU."; + } + + // + // Establish the host SEScope. + // + host_se_scope = se_scope_cache_.Unique(SEScope(host_device_type, + /*virtual_device_id=*/0, host_target)); + + // + // Now that we've settled on a host, make sure all the primitive Targets agree on it for + // their 'host' field. This mutates the primitives. + // + Array new_primitve_targets; + new_primitve_targets.reserve(primitive_targets.size()); + for (const auto& primitive_target : primitive_targets) { + new_primitve_targets.push_back(Target(primitive_target, host_target)); + } + primitive_targets = new_primitve_targets; + + // + // Gather the hints as to what our default device type for primitives should be. + // + DLDeviceType default_primitive_device_type; + Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); + if (opt_fallback_dev) { + const int64_t v = opt_fallback_dev.value()->value; + CHECK_GT(v, 0) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v; + default_primitive_device_type = static_cast(v); + DLOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; + } else if (primitive_targets.size() == 1) { + // In the homogeneous case there's no free choice. + default_primitive_device_type = + static_cast(primitive_targets.front()->kind->device_type); + DLOG(INFO) << "Using the device type " << default_primitive_device_type + << " of the unique primitive target as the default device type for all primitive " + << "operations"; + } else { + // Fallback. Note that we'll require a primitive Target of kDLCPU device_type to be given + // and won't manufacture one out of thin air. + default_primitive_device_type = kDLCPU; + DLOG(WARNING) << "Using " << default_primitive_device_type + << " as the default device type for all primitive operations"; + } + + // + // Establish the default primitive SEScope, choosing a known Target to match the device type. + // + default_primitive_se_scope = se_scope_cache_.Unique( + SEScope(default_primitive_device_type, + /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); +} + +/* static */ Target CompilationConfigNode::MakeDefaultTarget(DLDeviceType device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") { + if (runtime::Registry::Get("codegen.LLVMModuleCreate")) { + // LLVM is available. + return Target("llvm"); + } else { + // LLVM is not available. + // TODO(mbs): Already deprecated? + return Target("stackvm"); + } + } else { + return Target(name); + } +} + +Target CompilationConfigNode::FindPrimitiveTargetOrFail(DLDeviceType device_type) const { + auto itr = std::find_if( + primitive_targets.begin(), primitive_targets.end(), + [device_type](const Target& target) { return target->kind->device_type == device_type; }); + CHECK(itr != primitive_targets.end()) << "No target for device type " << device_type << " in the " + << primitive_targets.size() << " given by the targets list"; + return *itr; +} + +CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, + TargetMap legacy_target_map_arg, + Target optional_host_target_arg) { + VLOG_CONTEXT << "CompilationConfig"; + + auto node = make_object(); + + for (const auto& pair : legacy_target_map_arg) { + VLOG(0) << "Available primitive target " << pair.first << " = " << pair.second->ToDebugString(); + } + if (optional_host_target_arg.defined()) { + VLOG(0) << "Available host target " << optional_host_target_arg->ToDebugString(); + } + + // Capture the arguments in our representation. + for (const auto& pair : legacy_target_map_arg) { + node->primitive_targets.push_back(pair.second); + } + node->host_target = optional_host_target_arg; + + // Complete the targets vector and establish default scopes. After this primitive_targets will + // contain the definitive list of all required targets, target_host will be defined, and + // all primitive targets will have host target_host. + node->EstablishDefaultSEScopes(pass_ctx); + + // LEGACY: Reconstruct the target map with all the primitive targets. + for (const auto& primitive_target : node->primitive_targets) { + node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); + } + + ICHECK(node->default_primitive_se_scope->target.defined()); + ICHECK(node->host_se_scope->target.defined()); + ICHECK_GT(node->primitive_targets.size(), 0U); + + // Legacy: Some passes only support homogenous compilation and expect the target to be + // given by the global target context. Make this easy to detect. + node->optional_homogeneous_target = + node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); + + for (const auto& target : node->primitive_targets) { + DLOG(INFO) << "Target " << target->ToDebugString() << " of device type " + << target->kind->device_type << " is available for primitives"; + } + DLOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; + DLOG(INFO) << "Using host scope " << node->host_se_scope; + + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("target.MakeCompilationConfig") + .set_body_typed([](const transform::PassContext& pass_ctx, TargetMap legacy_target_map, + Target optional_host_target) -> CompilationConfig { + return CompilationConfig(pass_ctx, std::move(legacy_target_map), + std::move(optional_host_target)); + }); + +} // namespace tvm diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc new file mode 100644 index 000000000000..95d5a7de5775 --- /dev/null +++ b/src/target/se_scope.cc @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/se_scope.cc + * \brief Implementation of \p SEScope for representing a Storage or Execution scope. + */ +#include +#include +#include + +namespace tvm { + +TVM_REGISTER_NODE_TYPE(SEScopeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "SEScope("; + if (node->IsFullyUnconstrained()) { + p->stream << "?"; + } else { + bool need_sep = false; + if (node->device_type() != kInvalidDeviceType) { + p->stream << "device_type=" << node->device_type(); + need_sep = true; + } + if (node->virtual_device_id >= 0) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "virtual_device_id=" << node->virtual_device_id; + need_sep = true; + } + if (node->target.defined()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "target=" << node->target->ToDebugString(); + need_sep = true; + } + if (!node->memory_scope.empty()) { + if (need_sep) { + p->stream << ", "; + } + p->stream << "memory_scope='" << node->memory_scope << "'"; + } + } +#if TVM_LOG_DEBUG + // We rely on object identity of SEScopes, so include the object address to help debugging. + p->stream << ", id=" << reinterpret_cast(ref.get()); +#endif + p->stream << ")"; + }); + +SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { + ICHECK(!target.defined() || device_type == target->kind->device_type) + << "target " << target->ToDebugString() << " has device type " << target->kind->device_type + << " but scope has device type " << device_type; + auto node = make_object(); + node->device_type_int = device_type; + node->virtual_device_id = virtual_device_id; + node->target = std::move(target); + node->memory_scope = std::move(memory_scope); + data_ = std::move(node); +} + +/* static */ SEScope SEScope::FullyUnconstrained() { + static const SEScope unconstrained{}; + return unconstrained; +} + +/* static */ +Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType joined_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + joined_device_type = lhs->device_type(); + if (rhs->device_type() != kInvalidDeviceType && lhs->device_type() != rhs->device_type()) { + return {}; + } + } else { + joined_device_type = rhs->device_type(); + } + int joined_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + joined_virtual_device_id = lhs->virtual_device_id; + if (rhs->virtual_device_id >= 0 && lhs->virtual_device_id != rhs->virtual_device_id) { + return {}; + } + } else { + joined_virtual_device_id = rhs->virtual_device_id; + } + Target joined_target; + if (lhs->target.defined()) { + joined_target = lhs->target; + if (rhs->target.defined() && lhs->target != rhs->target) { + return {}; + } + } else { + joined_target = rhs->target; + } + MemoryScope joined_memory_scope; + if (!lhs->memory_scope.empty()) { + joined_memory_scope = lhs->memory_scope; + if (!rhs->memory_scope.empty() && lhs->memory_scope != rhs->memory_scope) { + return {}; + } + } else { + joined_memory_scope = rhs->memory_scope; + } + return SEScope(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope); +} + +/* static */ +SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { + if (lhs == rhs) { + return lhs; + } + DLDeviceType defaulted_device_type; + if (lhs->device_type() != kInvalidDeviceType) { + defaulted_device_type = lhs->device_type(); + } else { + defaulted_device_type = rhs->device_type(); + } + int defaulted_virtual_device_id; + if (lhs->virtual_device_id >= 0) { + defaulted_virtual_device_id = lhs->virtual_device_id; + } else { + defaulted_virtual_device_id = rhs->virtual_device_id; + } + Target defaulted_target; + if (lhs->target.defined()) { + defaulted_target = lhs->target; + } else { + // We can only default to the rhs's target if it is consistent with the device type + if (rhs->target.defined() && rhs->target->kind->device_type == defaulted_device_type) { + defaulted_target = rhs->target; + } + // else: leave as null + } + MemoryScope defaulted_memory_scope; + if (!lhs->memory_scope.empty()) { + defaulted_memory_scope = lhs->memory_scope; + } else { + defaulted_memory_scope = rhs->memory_scope; + } + return SEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, + defaulted_memory_scope); +} + +SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { + SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); + auto itr = cache_.find(prototype); + if (itr == cache_.end()) { + VLOG(1) << "added new scope " << prototype; + cache_.emplace(prototype); + return prototype; + } else { + VLOG(1) << "reusing existing scope " << *itr; + ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); + if (prototype->target.defined()) { + ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); + } + return *itr; + } +} + +SEScope SEScopeCache::Unique(const SEScope& scope) { + return Make(scope->device_type(), scope->virtual_device_id, scope->target, scope->memory_scope); +} + +TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope") + .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope); + +} // namespace tvm diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 0aad18ffb6f9..a52564c34a68 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,7 +525,7 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as()->value; if (sync == "warp") { // DO nothing. - } else if (sync == "shared") { + } else if (sync == "shared" || sync == "shared.dyn") { this->PrintIndent(); this->stream << "__syncthreads();\n"; } else if (sync == "global") { diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index d93a7fde639a..507a6243cb0c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -478,6 +478,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N } } +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, CodeGenOpenCL* p) { + if (op->dtype.lanes() == 1) { + os << opstr << "(("; + p->PrintType(op->a->dtype, os); + os << ")"; + p->PrintExpr(op->a, os); + os << ", ("; + p->PrintType(op->b->dtype, os); + os << ")"; + p->PrintExpr(op->b, os); + os << ')'; + } else { + p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + } +} + +void CodeGenOpenCL::VisitExpr_(const MinNode* op, std::ostream& os) { + PrintBinaryExpr(op, "min", os, this); +} + +void CodeGenOpenCL::VisitExpr_(const MaxNode* op, std::ostream& os) { + PrintBinaryExpr(op, "max", os, this); +} + void CodeGenOpenCL::SetTextureScope( const std::unordered_map& scope) { // NOLINT(*) for (auto& texture : scope) { diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index a8c293c03056..8c36a817753c 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -65,6 +65,10 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const StoreNode* op) final; // NOLINT(*) + // overload min and max to avoid ambiguous call errors + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; + private: // whether enable fp16 and fp64 extension bool enable_fp16_{false}; diff --git a/src/target/target.cc b/src/target/target.cc index e0b9539380d7..6f5e8ee67b30 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -74,7 +74,7 @@ void CheckAndUpdateHostConsistency(Target* target, Target* host) { *host = (*target)->GetHost().value_or(Target()); } -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { +void CheckAndUpdateHostConsistency(TargetMap* targets, Target* host) { Map new_targets; for (auto& it : *targets) { auto target = it.second; @@ -457,6 +457,7 @@ const std::string& TargetNode::str() const { if (Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } + str_repr_ = os.str(); } return str_repr_; @@ -531,6 +532,61 @@ Optional TargetNode::GetHost() const { return GetRef>(this->host.as()); } +String TargetNode::ToDebugString() const { + std::ostringstream os; + os << "Target("; + os << "kind='" << kind->name << "'"; + if (!tag.empty()) { + os << ", tag='" << tag << "'"; + } + if (!keys.empty()) { + os << ", keys={"; + bool first = true; + for (const auto& key : keys) { + if (!first) { + os << ", "; + } + os << "'" << key << "'"; + first = false; + } + os << "}"; + } + if (!attrs.empty()) { + os << ", attrs={"; + bool first = true; + for (const auto& pair : attrs) { + if (!first) { + os << ", "; + } + os << '"' << pair.first << "': " << pair.second; + first = false; + } + os << "}"; + } + if (host.defined()) { + os << ", host=" << GetHost().value()->ToDebugString(); + } +#if TVM_LOG_DEBUG + // We depend on pointer equality so include that in the debug representation. + os << ", id=" << reinterpret_cast(this); +#endif + os << ")"; + return os.str(); +} + +bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { + return equal(kind.get(), other->kind.get()) && equal(host, other->host) && + equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); +} + +void TargetNode::SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(kind.get()); + hash_reduce(host); + hash_reduce(tag); + hash_reduce(keys); + hash_reduce(attrs); +} + /*! \brief Entry to hold the Target context stack. */ struct TVMTargetThreadLocalEntry { /*! \brief The current target context */ diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 657dc121961c..d90681a1c0db 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -22,6 +22,7 @@ #include #include +#include #include "../schedule/graph.h" @@ -300,9 +301,40 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { return (*complete)(func, info.root_alloc); } // namespace tir -TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed([](const Array& tensors) { - return CreatePrimFunc(tensors); -}); +PrimFunc CreatePrimFuncFromOutputs(const Array& outputs) { + std::vector stack; + std::unordered_set visited; + for (const te::Tensor& output : outputs) { + if (!visited.count(output.get())) { + visited.insert(output.get()); + stack.push_back(output); + } + } + + Array arg_list; + while (!stack.empty()) { + te::Tensor tensor = stack.back(); + stack.pop_back(); + if (tensor->op->IsInstance()) { + arg_list.push_back(tensor); + } else if (tensor->op->IsInstance()) { + Array inputs = tensor->op->InputTensors(); + for (const te::Tensor& input : inputs) { + if (!visited.count(input.get())) { + visited.insert(input.get()); + stack.push_back(input); + } + } + } + } + for (const te::Tensor& output : outputs) { + arg_list.push_back(output); + } + return CreatePrimFunc(arg_list); +} + +TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); +TVM_REGISTER_GLOBAL("te.CreatePrimFuncFromOutputs").set_body_typed(CreatePrimFuncFromOutputs); } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 90aaa35d60d8..776538adbc0f 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -97,12 +97,14 @@ class BlockReadWriteDetector : public StmtExprVisitor { void UpdateOpaque(const Var& buffer_var); void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const LoadNode* op) override; void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const CallNode* op) override; }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { @@ -154,6 +156,38 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { dom_map_.erase(op->loop_var.get()); } +void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { + VisitExpr(op->condition); + { + // Visit then branch + With ctx(op->condition, &dom_map_, true); + StmtExprVisitor::VisitStmt(op->then_case); + } + if (op->else_case.defined()) { + // Visit else branch + With ctx(op->condition, &dom_map_, false); + StmtExprVisitor::VisitStmt(op->else_case); + } +} + +void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::if_then_else())) { + VisitExpr(op->args[0]); + { + // Visit then branch + With ctx(op->args[0], &dom_map_, true); + StmtExprVisitor::VisitExpr(op->args[1]); + } + { + // Visit else branch + With ctx(op->args[0], &dom_map_, false); + StmtExprVisitor::VisitExpr(op->args[2]); + } + return; + } + StmtExprVisitor::VisitExpr_(op); +} + void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { UpdateOpaque(op->buffer_var); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index efffa9031ac0..dc1ed1c193e8 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -198,12 +198,12 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const StoreNode* op) { - if (op->index->dtype.lanes() > 1) { - if (static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) > + if (op->value->dtype.lanes() > 1) { + if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; - s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes (" - << op->index->dtype.bytes() << ") for dtype " << op->index->dtype + s << "Number of lanes (" << op->value->dtype.lanes() << ") times number of bytes (" + << op->value->dtype.bytes() << ") for dtype " << op->value->dtype << " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")"; errors_.push_back(s.str()); } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index afc5c36ebb92..1d7c959d990d 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -90,6 +90,18 @@ Var Var::copy_with_suffix(const String& suffix) const { return Var(new_ptr); } +Var Var::copy_with_dtype(DataType dtype) const { + const VarNode* node = get(); + ObjectPtr new_ptr; + if (auto* ptr = this->as()) { + new_ptr = make_object(*ptr); + } else { + new_ptr = make_object(*node); + } + new_ptr->dtype = std::move(dtype); + return Var(new_ptr); +} + TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type, Span span) { if (type.IsObjectRef()) { @@ -904,6 +916,35 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // CommReducer CommReducer::CommReducer(Array lhs, Array rhs, Array result, Array identity_element, Span span) { + size_t n_group = result.size(); + CHECK_EQ(lhs.size(), n_group) << "ValueError: The number of vars in `lhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(rhs.size(), n_group) << "ValueError: The number of vars in `rhs` must equal to the " + "number of elements in `results`"; + CHECK_EQ(identity_element.size(), n_group) + << "ValueError: The number of identities must equal to the number of elements in `results`"; + + // Change the dtype of input vars to adapt to the dtype of identities + ArrayNode* p_lhs = lhs.CopyOnWrite(); + ArrayNode* p_rhs = rhs.CopyOnWrite(); + std::unordered_map var_map; + var_map.reserve(n_group * 2); + for (int i = 0; i < static_cast(n_group); ++i) { + DataType dtype = identity_element[i].dtype(); + Var l = lhs[i].copy_with_dtype(dtype); + Var r = rhs[i].copy_with_dtype(dtype); + var_map[lhs[i].get()] = l; + var_map[rhs[i].get()] = r; + + p_lhs->SetItem(i, l); + p_rhs->SetItem(i, r); + } + + ArrayNode* p_result = result.CopyOnWrite(); + for (int i = 0; i < static_cast(n_group); ++i) { + p_result->SetItem(i, Substitute(result[i], var_map)); + } + auto node = make_object(); node->lhs = lhs; node->rhs = rhs; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d14d64a4c787..e3a535e9b3d4 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -505,8 +505,8 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, if (const ForNode* loop = p->StmtAs()) { if (loop->kind == ForKind::kThreadBinding) { const String& thread_tag = loop->thread_binding.value()->thread_tag; - if (CanRelaxStorageUndereThread(extra_relax_scope, - runtime::ThreadScope::Create(thread_tag))) { + if (CanRelaxStorageUnderThread(extra_relax_scope, + runtime::ThreadScope::Create(thread_tag))) { result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 42839075af30..4db4cd4ba1c8 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -232,6 +232,16 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, throw; } +Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, + max_innermost_factor, &decision)); + TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -282,6 +292,38 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } +Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + +Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + +Array ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::GetProducers(state_, this->GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_); + throw; +} + +Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::GetConsumers(state_, this->GetSRef(block_rv))); + TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_); + throw; +} + /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 1f9aeecfc776..035c16f506cf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -21,6 +21,7 @@ #include #include +#include #include "./utils.h" @@ -80,19 +81,17 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. - * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) override; + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; + Array GetChildBlocks(const BlockRV& block_rv) override; + Array GetChildBlocks(const LoopRV& loop_rv) override; + Array GetProducers(const BlockRV& block_rv) override; + Array GetConsumers(const BlockRV& block_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) override; @@ -154,6 +153,12 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variable created */ inline ExprRV CreateRV(int64_t value); + /*! + * \brief Add a list of integers as random variables into the symbol table + * \param value The list of integers to be added to the symbol table + * \return The new random variables created + */ + inline Array CreateRV(const std::vector& value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -274,6 +279,15 @@ inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { return std::move(rv); } +inline Array ConcreteScheduleNode::CreateRV(const std::vector& value) { + Array results; + results.reserve(value.size()); + for (int64_t v : value) { + results.push_back(CreateRV(v)); + } + return results; +} + inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index d8dcf57b91e4..4ce5a97bb5d3 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -24,29 +24,37 @@ namespace tir { String ScheduleError::RenderReport(const String& primitive) const { IRModule mod = this->mod(); std::ostringstream os; - os << "ScheduleError: An error occurred in the schedule primitive '" << primitive - << "'.\n\nThe IR is:\n" - << AsTVMScript(mod); + + // get locations of interest Array locs = LocationsOfInterest(); + std::unordered_map loc_obj_to_name; int n_locs = locs.size(); - std::vector roi_names; - roi_names.reserve(n_locs); - if (n_locs > 0) { - os << "Regions of interest:\n"; - for (const ObjectRef& obj : locs) { - String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); - os << name << "\n" << obj; - roi_names.emplace_back(std::move(name)); - } - os << "\n"; - } std::string msg = DetailRenderTemplate(); - for (int i = 0; i < n_locs; ++i) { - std::string src = "{" + std::to_string(i) + "}"; - for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { - msg.replace(pos, src.length(), roi_names[i]); + if (n_locs > 0) { + for (int i = 0; i < n_locs; ++i) { + std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), name); + } + loc_obj_to_name.emplace(locs[i], std::move(name)); } } + + // print IR module + runtime::TypedPackedFunc annotate = + runtime::TypedPackedFunc( + [&loc_obj_to_name](const Stmt& expr) -> std::string { + auto it = loc_obj_to_name.find(Downcast(expr)); + if (it == loc_obj_to_name.end()) return ""; + return it->second; + }); + + os << "ScheduleError: An error occurred in the schedule primitive '" << primitive + << "'.\n\nThe IR with diagnostic is:\n" + << AsTVMScriptWithDiagnostic(mod, "T", false, annotate); + + // print error message os << "Error message: " << msg; return os.str(); } diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 057e845dbd48..cc7e44d4df9e 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,6 +22,8 @@ #include #include +#include + namespace tvm { namespace tir { @@ -32,11 +34,10 @@ namespace tir { * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. */ -TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive); +TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t min_inclusive, int32_t max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. - * \param self The schedule to update * \param rand_state The pointer to schedule's random state * \param candidates The candidates * \param probs The probability distribution of the candidates @@ -46,6 +47,40 @@ TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Sample the factors to perfect tile a specific loop + * \param rand_state The random state + * \param extent The loop extent to be tiled + * \param n_split The number of tiles to be sampled + * \return A list of length `n`, the random perfect tile sizes sampled + */ +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_splits); +/*! + * \brief Sample the factors to perfect tile a specific loop + * \param rand_state The random state + * \param extent The loop extent to be tiled + * \param n_split The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \return A list of length `n`, the random perfect tile sizes sampled + */ +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + int32_t extent, int32_t n_split, int32_t max_innermost_factor); +/*! + * \brief Sample the factors to perfect tile a specific loop + * \param rand_state The random state + * \param loop_sref The loop to be tiled + * \param n_split The number of tiles to be sampled + * \param max_innermost_factor The maximum tile size allowed to be sampled in the innermost loop + * \param decision The sampling decision + * \return A list of length `n`, the random perfect tile sizes sampled + */ +TVM_DLL std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, + Optional>* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -63,6 +98,27 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); +/*! + * \brief Get the leaf blocks of a specific block/loop + * \param self The schedule state + * \param parent_sref The query block/loop + * \return A list of leaf blocks inside a specific block/loop + */ +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +/*! + * \brief Get the producers of a specific block + * \param self The schedule state + * \param block_sref The block in the query + * \return A list of blocks, the producers of the given block + */ +Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the consumers of a specific block + * \param self The schedule state + * \param block_rv The block in the query + * \return A list of blocks, the consumers of the given block + */ +Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 008d47792f69..55869e12b6b2 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -121,6 +121,11 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind runtime::ThreadScope thread_scope) { PreOrderVisit(loop, [&](const ObjectRef& node) { if (const auto* realize = node.as()) { + // If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block + // inside `tir.init()`. We don't check the condition for such blocks. + if (!self->stmt2ref.count(realize->block.get())) { + return false; + } CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef(realize), thread_scope); } diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 8b32a9c14f58..c044de3bc644 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -55,6 +55,56 @@ Array GetLoops(const StmtSRef& block_sref) { return {result.rbegin(), result.rend()}; } +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { + struct Collector : public StmtVisitor { + private: + void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } + + public: + explicit Collector(const ScheduleState& self) : self(self) {} + + const ScheduleState& self; + Array result; + }; + Collector collector(self); + if (parent_sref->stmt->IsInstance()) { + const auto* loop = static_cast(parent_sref->stmt); + collector(loop->body); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); + collector(block->body); + } + return std::move(collector.result); +} + +Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_stage_pipeline=*/false); + Array edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref); + Array results; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) { + results.push_back(edge->src); + } + } + return results; +} + +Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_stage_pipeline=*/false); + Array edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref); + Array results; + results.reserve(edges.size()); + for (const Dependency& edge : edges) { + if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) { + results.push_back(edge->dst); + } + } + return results; +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -106,8 +156,90 @@ struct GetLoopsTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetChildBlocksTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetChildBlocks"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(block)); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(loop)); + } + LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, String block_or_loop_rv) { + PythonAPICall py("get_child_blocks"); + py.Input("", block_or_loop_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct GetProducersTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetProducers"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetProducers(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_producers"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct GetConsumersTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetConsumers"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetConsumers(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_consumers"); + py.Input("block", block_rv); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6ac6226118cd..171838572dbb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,26 +24,151 @@ namespace tvm { namespace tir { -int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, - int max_exclusive) { +struct PrimeTable { + /*! \brief The table contains prime numbers in [2, kMaxPrime) */ + static constexpr const int32_t kMaxPrime = 65536; + /*! \brief The exact number of prime numbers in the table */ + static constexpr const int32_t kNumPrimes = 6542; + /*! + * \brief For each number in [2, kMaxPrime), the index of its min factor. + * For example, if min_factor_idx[x] = i, then the min factor of x is primes[i]. + */ + int32_t min_factor_idx[kMaxPrime]; + /*! \brief The prime numbers in [2, kMaxPrime) */ + std::vector primes; + /*! + * \brief The power of each prime number. + * pow_table[i, j] stores the result of pow(prime[i], j + 1) + */ + std::vector> pow_tab; + + /*! \brief Get a global instance of the prime table */ + static const PrimeTable* Global() { + static const PrimeTable table; + return &table; + } + + /*! \brief Constructor, pre-computes all info in the prime table */ + PrimeTable() { + constexpr const int64_t int_max = std::numeric_limits::max(); + // Euler's sieve: prime number in linear time + for (int32_t i = 0; i < kMaxPrime; ++i) { + min_factor_idx[i] = -1; + } + primes.reserve(kNumPrimes); + for (int32_t x = 2; x < kMaxPrime; ++x) { + if (min_factor_idx[x] == -1) { + min_factor_idx[x] = primes.size(); + primes.push_back(x); + } + for (size_t i = 0; i < primes.size(); ++i) { + int64_t factor = primes[i]; + int64_t y = x * factor; + if (y >= kMaxPrime) { + break; + } + min_factor_idx[y] = i; + if (x % factor == 0) { + break; + } + } + } + ICHECK_EQ(static_cast(primes.size()), static_cast(kNumPrimes)); + // Calculate the power table for each prime number + pow_tab.reserve(primes.size()); + for (int32_t prime : primes) { + std::vector tab; + tab.reserve(32); + for (int64_t pow = prime; pow <= int_max; pow *= prime) { + tab.push_back(pow); + } + tab.shrink_to_fit(); + pow_tab.emplace_back(std::move(tab)); + } + } + /*! + * \brief Factorize a number n, and return in a cryptic format + * \param n The number to be factorized + * \return A list of integer pairs [(i_1, j_1), (i_2, j_2), ..., (i_l, j_l)] + * For each pair (i, j), we define + * (a, b) = (j, 1) if i == -1 (in this case j must be a prime number) + * (primes[i], j) if i != -1 + * Then the factorization is + * n = (a_1 ^ b_1) * (a_2 ^ b_2) ... (a_l ^ b_l) + */ + std::vector> Factorize(int32_t n) const { + std::vector> result; + result.reserve(16); + int32_t i = 0, n_primes = primes.size(); + // Phase 1: n >= kMaxPrime + for (int32_t j; n >= kMaxPrime && i < n_primes && primes[i] * primes[i] <= n; ++i) { + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + if (j != 0) { + result.emplace_back(i, j); + } + } + // if i >= n_primes or primes[i] > sqrt(n), then n must be a prime number + if (n >= kMaxPrime) { + result.emplace_back(-1, n); + return result; + } + // Phase 2: n < kMaxPrime + for (int32_t j; n > 1;) { + int32_t i = min_factor_idx[n]; + for (j = 0; n % primes[i] == 0; n /= primes[i], ++j) { + } + result.emplace_back(i, j); + } + return result; + } +}; + +int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, + int32_t max_exclusive) { CHECK(min_inclusive < max_exclusive) << "ValueError: max_exclusive must be greater than min_inclusive."; if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } support::LinearCongruentialEngine rand_(rand_state); - std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); return dist(rand_); } +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k) { + if (k == 1) { + return {SampleInt(rand_state, 0, n)}; + } + if (k == 2) { + int32_t result0 = SampleInt(rand_state, 0, n); + int32_t result1 = SampleInt(rand_state, 0, n - 1); + if (result1 >= result0) { + result1 += 1; + } + return {result0, result1}; + } + std::vector order(n); + for (int32_t i = 0; i < n; ++i) { + order[i] = i; + } + for (int32_t i = 0; i < k; ++i) { + int32_t j = SampleInt(rand_state, i, n); + if (i != j) { + std::swap(order[i], order[j]); + } + } + return {order.begin(), order.begin() + k}; +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; - int i = -1; - int n = candidates.size(); - + int32_t i = -1; + int32_t n = candidates.size(); if (decision->defined()) { const auto* int_imm = decision->as(); i = int_imm->value; @@ -51,7 +176,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } else { std::vector weights = support::AsVector(probs); - std::discrete_distribution dist(weights.begin(), weights.end()); + std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n @@ -62,6 +187,151 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits) { + CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; + CHECK_GE(n_splits, 1) << "ValueError: Cannot tile a loop to 0 or negative splits"; + // Handle special case that we can potentially accelerate + if (n_splits == 1) { + return {extent}; + } + if (extent == 1) { + return std::vector(n_splits, 1); + } + // Enumerate each pair (i, j), we define + // (a, p) = (j, 1) if i == -1 (in this case j must be a prime number) + // (primes[i], j) if i != -1 + // Then the factorization is + // extent = (a_1 ^ p_1) * (a_2 ^ p_2) ... (a_l ^ p_l) + const PrimeTable* prime_tab = PrimeTable::Global(); + std::vector> factorized = prime_tab->Factorize(extent); + if (n_splits == 2) { + // n_splits = 2, this can be taken special care of, + // because general reservoir sampling can be avoided to accelerate the sampling + int32_t result0 = 1; + int32_t result1 = 1; + for (const std::pair& ij : factorized) { + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + (SampleInt(rand_state, 0, 2) ? result1 : result0) *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + int32_t x1 = SampleInt(rand_state, 0, p + 1); + int32_t x2 = p - x1; + if (x1 != 0) { + result0 *= pow[x1]; + } + if (x2 != 0) { + result1 *= pow[x2]; + } + } + return {result0, result1}; + } + // Data range: + // 2 <= extent <= 2^31 - 1 + // 3 <= n_splits <= max tiling splits + // 1 <= p <= 31 + std::vector result(n_splits, 1); + for (const std::pair& ij : factorized) { + // Handle special cases to accelerate sampling + // Case 1: (a, p) = (j, 1), where j is a prime number + if (ij.first == -1) { + result[SampleInt(rand_state, 0, n_splits)] *= ij.second; + continue; + } + // Case 2: (a = primes[i], p = 1) + int32_t p = ij.second; + if (p == 1) { + result[SampleInt(rand_state, 0, n_splits)] *= prime_tab->primes[ij.first]; + continue; + } + // The general case. We have to sample uniformly from the solution of: + // x_1 + x_2 + ... + x_{n_splits} = p + // where x_i >= 0 + // Data range: + // 2 <= p <= 31 + // 3 <= n_splits <= max tiling splits + std::vector sampled = + SampleWithoutReplacement(rand_state, p + n_splits - 1, n_splits - 1); + std::sort(sampled.begin(), sampled.end()); + sampled.push_back(p + n_splits - 1); + const int32_t* pow = prime_tab->pow_tab[ij.first].data() - 1; + for (int32_t i = 0, last = -1; i < n_splits; ++i) { + int32_t x = sampled[i] - last - 1; + last = sampled[i]; + if (x != 0) { + result[i] *= pow[x]; + } + } + } + return result; +} + +std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, + int32_t extent, int32_t n_splits, + int32_t max_innermost_factor) { + if (max_innermost_factor == -1) { + return SamplePerfectTile(rand_state, extent, n_splits); + } + CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int32_t i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); + } + } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); + result.push_back(innermost); + return result; +} + +std::vector SamplePerfectTile( + support::LinearCongruentialEngine::TRandState* rand_state, // + const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + Optional>* decision) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + int64_t extent = GetLoopIntExtent(loop); + std::vector result; + if (extent == -1) { + // Case 1. Handle loops with non-constant length + result = std::vector(n_splits, 1); + result[0] = -1; + } else if (decision->defined()) { + // Case 2. Use previous decision + result = support::AsVector(decision->value()); + int n = result.size(); + ICHECK_GE(n, 2); + int64_t len = extent; + for (int i = n - 1; i > 0; --i) { + int64_t& l = result[i]; + // A previous decision could become invalid because of the change of outer tiles + // To handle this case properly, we check if the tiling strategy is still perfect. + // If not, we use a trivial default solution (1, 1, ..., 1, L) for rest of the tiles + if (len % l != 0) { + l = len; + } + len /= l; + } + result[0] = len; + } else { + // Case 3. Use fresh new sampling result + result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + ICHECK_LE(result.back(), max_innermost_factor); + } + *decision = support::AsArray(result); + return result; +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -96,7 +366,38 @@ struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SamplePerfectTile"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static Array UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Integer n, + Integer max_innermost_factor, + Optional> decision) { + return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, Integer n, + Integer max_innermost_factor, Optional> decision) { + PythonAPICall py("sample_perfect_tile"); + py.Input("loop", loop_rv); + py.Input("n", n->value); + py.Input("max_innermost_factor", max_innermost_factor->value); + py.Decision(decision); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); +TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 84a37c392e81..a411e40b13b6 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,11 +123,29 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") /******** (FFI) Sampling ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") + .set_body_method(&ScheduleNode::SamplePerfectTile); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") + .set_body_typed([](Schedule self, ObjectRef rv) { + if (const auto* block_rv = rv.as()) { + return self->GetChildBlocks(GetRef(block_rv)); + } + if (const auto* loop_rv = rv.as()) { + return self->GetChildBlocks(GetRef(loop_rv)); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". Its value is: " << rv; + throw; + }); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") + .set_body_method(&ScheduleNode::GetProducers); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") + .set_body_method(&ScheduleNode::GetConsumers); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index cc48f2b9e7ce..4a028d1dad5c 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -43,6 +43,7 @@ Schedule TracedScheduleNode::Copy() const { } /******** Schedule: Sampling ********/ + ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { @@ -57,6 +58,21 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, return result; } +Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, + int max_innermost_factor, + Optional> decision) { + Array results = CreateRV(tir::SamplePerfectTile( + &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{Integer(n), Integer(max_innermost_factor)}, + /*outputs=*/{results.begin(), results.end()}), + /*decision=*/decision); + return results; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -81,6 +97,50 @@ Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } +Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetProducers(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetProducers(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetProducers"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetConsumers(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetConsumers"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + /******** Schedule: Transform loops ********/ LoopRV TracedScheduleNode::Fuse(const Array& loop_rvs) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index fae5ca8608dd..ac36b9ca06a9 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,20 +47,17 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Sample an integer given the probability distribution - * \param candidates The candidates - * \param probs The probability distribution of the candidates - * \param decision The sampling decision, if it's given we would validate the decision, otherwise - * we would sample a decision from the distribution and set the decision accordingly. - * \return The random variable sampled from candidates - */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) final; - + Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, + Optional> decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; + Array GetChildBlocks(const BlockRV& block_rv) final; + Array GetChildBlocks(const LoopRV& loop_rv) final; + Array GetProducers(const BlockRV& block_rv) final; + Array GetConsumers(const BlockRV& block_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& loop_rv, const Array>& factor_rvs) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index a63a9f079617..c66c2ca76693 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -53,8 +53,8 @@ namespace tir { * \brief A helper macro to convert an sref to the statement it points to, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param SRef The SRef to be casted - * \param Type The type to be casted to, can be Block or For + * \param SRef The SRef to be cast + * \param Type The type to be cast to, can be Block or For */ #define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \ SRef->StmtAs(); \ @@ -64,7 +64,7 @@ namespace tir { * \brief A helper macro to convert an sref to the block it points to, * throwing an internal error if downcasting fails * \param Result The result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_BLOCK(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \ @@ -75,7 +75,7 @@ namespace tir { * \brief A helper macro to convert an sref to the for-loop it points to, * throwing an internal error if downcasting fails * \param Result The name of the result variable, used for checking - * \param SRef The SRef to be casted + * \param SRef The SRef to be cast */ #define TVM_SREF_TO_FOR(Result, SRef) \ TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \ @@ -86,8 +86,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * then check if the downcasting succeeded. * \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS_OR_ERR(Result, From, Type) \ From.as(); \ @@ -97,8 +97,8 @@ namespace tir { * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as`, * throwing an internal error if downcast fails. * \param Result The result variable, used for checking - * \param From The ObjectRef to be downcasted - * \param Type The type to be downcasted to + * \param From The ObjectRef to be downcast + * \param Type The type to be downcast to */ #define TVM_TYPE_AS(Result, From, Type) \ TVM_TYPE_AS_OR_ERR(Result, From, Type) \ @@ -129,8 +129,8 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { * \param thread_scope The thread scope to be relaxed * \return A boolean indicating the result */ -inline bool CanRelaxStorageUndereThread(const runtime::StorageScope& storage_scope, - const runtime::ThreadScope& thread_scope) { +inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scope, + const runtime::ThreadScope& thread_scope) { if (storage_scope.rank == runtime::StorageRank::kWarp) { // for warp memory, we only relax threadIdx.x return thread_scope.rank == 1 && thread_scope.dim_index == 0; @@ -210,6 +210,28 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } +/**************** Loop extents ****************/ + +/*! + * \brief Get the extents of a loop + * \param loop The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const ForNode* loop) { + const auto* int_extent = loop->extent.as(); + return int_extent ? int_extent->value : -1; +} + +/*! + * \brief Get the extents of a loop + * \param loop_sref The loop to be queried + * \return The extents of the loop + */ +inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + return GetLoopIntExtent(loop); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a1f488f386b3..07f977860d93 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include #include @@ -41,16 +42,19 @@ namespace tir { using support::NDIntSet; /*! - * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the - * int_set is empty. + * \brief simplify and return the region collected by NDIntSet. return the original + * buffer shape if the int_set is empty. */ -Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set, - const Array& original_shape) { +Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set, + const Array& original_shape, + arith::Analyzer* analyzer) { Array result; result.reserve(nd_int_set.size()); for (size_t i = 0; i < nd_int_set.size(); ++i) { const arith::IntSet& int_set = nd_int_set[i]; - result.push_back(int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]))); + Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i])); + result.push_back( + Range::FromMinExtent(analyzer->Simplify(range->min), analyzer->Simplify(range->extent))); } return result; } @@ -85,6 +89,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitStmt_(const BufferStoreNode* op) final { VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + VisitExpr(op->value); } void VisitExpr_(const BufferLoadNode* op) final { @@ -105,58 +110,91 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { ancestor_loops_.push_back(op); + Range loop_range = Range::FromMinExtent(op->min, op->extent); + dom_analyzer_.Bind(op->loop_var, loop_range); + dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range)); StmtExprVisitor::VisitStmt_(op); + dom_map_.erase(op->loop_var.get()); ancestor_loops_.pop_back(); - // The iter_dom_map is updated by post DFS order. - // If the union point is under the for node, the loop var will not be relaxed. - // If the union point is outer of the for loop, the loop var should be relaxed. - iter_dom_map_on_post_order_[op->loop_var.get()] = - arith::IntSet::FromMinExtent(op->min, op->extent); + } + + void VisitStmt_(const IfThenElseNode* op) final { + // Visit condition + StmtExprVisitor::VisitExpr(op->condition); + { + // Visit then branch + With ctx(op->condition, &dom_map_, true); + StmtExprVisitor::VisitStmt(op->then_case); + } + if (op->else_case.defined()) { + // Visit else branch + With ctx(op->condition, &dom_map_, false); + StmtExprVisitor::VisitStmt(op->else_case); + } + } + + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::if_then_else())) { + // Visit condition + StmtExprVisitor::VisitExpr(op->args[0]); + { + // Visit then branch + With ctx(op->args[0], &dom_map_, true); + StmtExprVisitor::VisitExpr(op->args[1]); + } + { + // Visit else branch + With ctx(op->args[0], &dom_map_, false); + StmtExprVisitor::VisitExpr(op->args[2]); + } + return; + } + return StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BlockNode* op) final { // Step 0. Check there is no init part. ICHECK(!op->init.defined()); - // Step 1. Update outer buffer access info using buffer region + // Step 1. Record and update current read/write region annotations + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + cur_access_annotations; for (const BufferRegion& region : op->reads) { - VisitBufferAccess(region); + cur_access_annotations[region->buffer].push_back(region); } for (const BufferRegion& region : op->writes) { - VisitBufferAccess(region); + cur_access_annotations[region->buffer].push_back(region); } - - // Step 2. Update inner buffer - // Step 2.1. rebuild map buffer_var_in_scope - std::unordered_map buffer_var_in_scope; + for (auto& p : cur_access_annotations) { + auto& regions = access_annotations_[p.first]; + p.second.swap(regions); + } + // Step 2. Record relax position of ancestor_loops_ into buffer_var_in_scope_ for (const Buffer& buffer : op->alloc_buffers) { - buffer_var_in_scope.emplace(buffer->data, buffer); + buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, ancestor_loops_.size())); } - // Step 2.2 Record top stack element before recursive visiting. - size_t stack_top = buffer_access_stack_.size(); - - // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively - std::swap(buffer_var_in_scope, buffer_var_in_scope_); + // Step 3. Visit match buffers + for (const MatchBufferRegion& region : op->match_buffers) { + VisitBufferAccess(region->source); + } + // Step 4. Visit block body recursively StmtExprVisitor::VisitStmt_(op); - std::swap(buffer_var_in_scope, buffer_var_in_scope_); - - // Step 2.4. Combine and relax access - std::unordered_map relaxed_region = - CombineAndRelax(stack_top); - - // Step 2.5. Visit ancestor_loops and try to relax outer thread loops. + // Step 5. Recover read/write region annotations + for (auto& p : cur_access_annotations) { + auto& regions = access_annotations_[p.first]; + if (p.second.empty()) { + access_annotations_.erase(p.first); + } else { + regions.swap(p.second); + } + } + // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers. for (const Buffer& buffer : op->alloc_buffers) { - auto it = relaxed_region.find(buffer); - ICHECK(it != relaxed_region.end()); + auto it = relaxed_accesses_.find(buffer); + ICHECK(it != relaxed_accesses_.end()) + << buffer << " is allocated but not accessed within block scope"; const NDIntSet& nd_int_set = it->second; - std::unordered_map dom_map; - for (const ForNode* loop : ancestor_loops_) { - const VarNode* loop_var = loop->loop_var.get(); - if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { - dom_map[loop_var] = arith::IntSet::FromMinExtent(loop->min, loop->extent); - } - } - NDIntSet int_set = support::NDIntSetEval(nd_int_set, dom_map); - buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape); + buffer_access_region_[buffer] = + SimplifyAndNarrowBufferRegionFromNDIntSet(nd_int_set, buffer->shape, &dom_analyzer_); } } @@ -166,61 +204,54 @@ class BufferAccessRegionCollector : public StmtExprVisitor { const BufferNode* buffer = buffer_region->buffer.get(); auto it = buffer_var_in_scope_.find(buffer->data); if (it != buffer_var_in_scope_.end()) { - const Buffer& buffer = it->second; - const BufferAccessInfo* info = - arena_.make(buffer, support::NDIntSetFromRegion(buffer_region->region)); - buffer_access_stack_.push(info); + const Buffer& buffer = it->second.first; + size_t n_ancestor_loops = it->second.second; + NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region); + // Step 1. Stop ancestor loop vars out of the allocation block from + // being relaxed unless NeedRelaxThread() is true. + std::vector non_relaxed(n_ancestor_loops); + for (size_t i = 0; i < n_ancestor_loops; ++i) { + const ForNode* loop = ancestor_loops_[i]; + const VarNode* v = loop->loop_var.get(); + if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer.scope()))) { + continue; + } + auto dom_it = dom_map_.find(v); + ICHECK(dom_it != dom_map_.end()); + non_relaxed[i] = dom_it->second; + dom_map_.erase(dom_it); + } + // Step 2. Relax the access region + nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_); + // Step 3. Restore the non-relaxed ancestor loops domain + for (size_t i = 0; i < n_ancestor_loops; ++i) { + const VarNode* v = ancestor_loops_[i]->loop_var.get(); + dom_map_.emplace(v, non_relaxed[i]); + } + // Step 4. Update relaxed_accesses_ dict + auto access_it = relaxed_accesses_.find(buffer); + if (access_it != relaxed_accesses_.end()) { + support::NDIntSetUnionWith(&access_it->second, nd_int_set); + } else { + relaxed_accesses_.insert(access_it, {buffer, nd_int_set}); + } } } void VisitBufferVar(const Var& var) { auto it = buffer_var_in_scope_.find(var); if (it != buffer_var_in_scope_.end()) { - const Buffer& buffer = it->second; - VisitBufferAccess(BufferRegion::FullRegion(buffer)); - } - } - - /*! - * \brief Combine buffer accesses in the sub-tree. - * \details The access info is stored in a stack by DFS order, so that the accesses in the - * sub-tree are top-n elements in the stack. - * \param stack_top compact the access information in `stack[stack_top:end]`. - */ - std::unordered_map CombineAndRelax( - size_t stack_top) { - std::unordered_map accesses; - while (buffer_access_stack_.size() > stack_top) { - const BufferAccessInfo* info = buffer_access_stack_.top(); - buffer_access_stack_.pop(); - NDIntSet nd_int_set = - support::NDIntSetEval(info->accessed_region, iter_dom_map_on_post_order_); - auto it = accesses.find(info->buffer); - if (it != accesses.end()) { - support::NDIntSetUnionWith(&it->second, nd_int_set); + const Buffer& buffer = it->second.first; + auto annotation_it = access_annotations_.find(buffer); + if (annotation_it != access_annotations_.end()) { + // opaque buffer has explicit accessed region annotations + for (const BufferRegion& region : annotation_it->second) { + VisitBufferAccess(region); + } } else { - accesses[info->buffer] = nd_int_set; + VisitBufferAccess(BufferRegion::FullRegion(buffer)); } } - return accesses; - } - - /*! - * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack. - * \details The access info is stored in a stack by DFS order, so that the accesses in the - * sub-tree are top-n elements in the stack. - * \param stack_top The top element of the stack before visiting the sub-tree. - */ - std::unordered_map CombineRelaxAndPushStack( - size_t stack_top) { - std::unordered_map accesses = - CombineAndRelax(stack_top); - for (const auto& kv : accesses) { - const Buffer& buffer = kv.first; - const NDIntSet& int_set = kv.second; - buffer_access_stack_.push(arena_.make(buffer, int_set)); - } - return accesses; } /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */ @@ -232,23 +263,34 @@ class BufferAccessRegionCollector : public StmtExprVisitor { const String& thread_tag = loop->thread_binding.value()->thread_tag; // When there is warp memory // threadIdx.x must be set to be warp index. - return CanRelaxStorageUndereThread(scope, runtime::ThreadScope::Create(thread_tag)); + return CanRelaxStorageUnderThread(scope, runtime::ThreadScope::Create(thread_tag)); } /**************** Class members ****************/ - - /*! \brief Buffer access in DFS order. */ - std::stack buffer_access_stack_; /*! \brief The loops from the current node up to the root. */ std::vector ancestor_loops_; - /*! \brief The vars of the buffer allocated under the current block. */ - std::unordered_map buffer_var_in_scope_; + + /*! + * \brief The vars of the buffer allocated under the current block. + * Map each buffer var to (buffer_obj, n_ancester_loop) pair, where + * n_ancester_loop is the loop num out of the current block. + * Tancestor_loops_[0: n_ancester_loop] should not be relaxed when + * we evaluate this buffer's access regions. + */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + buffer_var_in_scope_; + /*! \brief The map from loop vars to their iter range. */ - std::unordered_map iter_dom_map_on_post_order_; + std::unordered_map dom_map_; + /*! \brief The analyzer aware of loop domains. */ + arith::Analyzer dom_analyzer_; + /*! \brief The map from Buffer to it's relaxed access set. */ + std::unordered_map relaxed_accesses_; /*! \brief The map from Buffer to it entire access region, used for returning. */ std::unordered_map buffer_access_region_; - /*! \brief Internal arena. */ - support::Arena arena_; + /*! \brief The map from Buffer to it's access regions annotated by current block. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + access_annotations_; }; /*! \brief Collect storage alignment information from block annotations. */ diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 262906ade2e8..2423b09d4fb7 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -24,6 +24,7 @@ #include "ir_utils.h" #include +#include #include #include @@ -251,5 +252,88 @@ Bool IsFromLegacyTESchedule(PrimFunc f) { return from_legacy_te_schedule.value(); } +Map ConditionalBoundsContext::GetVarBoundsFromCondition() { + // extract equations and related vars from condition expression. + // currently only extract simple integral equations which could be solvable. + arith::Analyzer analyzer; + PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_); + Array equations; + std::unordered_set var_set; + std::function fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) { + if (e->IsInstance() || e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || e->IsInstance()) { + bool is_simple = true; + std::vector cand_vars; + PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) { + if (obj.same_as(e)) { + return; + } else if (const VarNode* var = obj.as()) { + if (var->dtype.is_int() || var->dtype.is_uint()) { + cand_vars.push_back(GetRef(var)); + } + } else { + is_simple &= obj->IsInstance() || obj->IsInstance() || + obj->IsInstance() || obj->IsInstance() || + obj->IsInstance() || obj->IsInstance(); + } + }); + if (is_simple && !cand_vars.empty()) { + for (const Var& var : cand_vars) var_set.insert(var); + equations.push_back(Downcast(e)); + } + } else if (e->IsInstance()) { + And op = Downcast(e); + fvisit(op->a); + fvisit(op->b); + } else if (e->IsInstance()) { + Call op = Downcast(e); + if (op->op.same_as(builtin::likely())) { + fvisit(op->args[0]); + } + } + }; + fvisit(condition); + if (equations.empty() || var_set.empty()) { + return Map(); + } + // build dom ranges for related vars + Array vars = Array(var_set.begin(), var_set.end()); + Map ranges; + for (const Var& v : vars) { + auto it = dom_map_->find(v.get()); + if (it != dom_map_->end()) { + const auto& int_set = it->second; + ranges.Set(v, Range::FromMinExtent(int_set.min(), + analyzer.Simplify(int_set.max() - int_set.min() + 1))); + } + } + // solve constraints + arith::IntConstraints constraint(vars, ranges, equations); + auto result = arith::SolveInequalitiesToRange(constraint); + return result->ranges; +} + +ConditionalBoundsContext::ConditionalBoundsContext( + const PrimExpr& condition, std::unordered_map* dom_map, + bool is_true_branch) + : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {} + +void ConditionalBoundsContext::EnterWithScope() { + for (const auto& p : GetVarBoundsFromCondition()) { + const auto* var = p.first.get(); + auto it = dom_map_->find(var); + if (it != dom_map_->end()) { + origin_map_.emplace(var, it->second); + it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)}); + } + } +} + +void ConditionalBoundsContext::ExitWithScope() { + for (const auto& p : origin_map_) { + (*dom_map_)[p.first] = p.second; + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 9be18b790b41..7b1d34c8162d 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -24,7 +24,9 @@ #ifndef TVM_TIR_TRANSFORMS_IR_UTILS_H_ #define TVM_TIR_TRANSFORMS_IR_UTILS_H_ +#include #include +#include #include #include #include @@ -32,6 +34,7 @@ #include #include +#include #include namespace tvm { @@ -224,6 +227,42 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region */ Bool IsFromLegacyTESchedule(PrimFunc f); +/*! + *\brief Context helper to update domain map within conditional scope. + * + * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is + *[0, 8]. Then `With ctx(&dom_map, bounds, true)` step into scope where + *dom_map[i] is [0, 8] and `With ctx(&dom_map, bounds, false)` step into + *scope where dom_map[i] is [9, 20] + */ +class ConditionalBoundsContext { + private: + friend class With; + /*! + * \brief Construct a condition bounds context. + * \param condition The condition holds on true branch. + * \param dom_map The global domain map to be updated. + * \param is_true_branch Whether step into the branch where condition bounds holds. + */ + ConditionalBoundsContext(const PrimExpr& condition, + std::unordered_map* dom_map, + bool is_true_branch); + void EnterWithScope(); + void ExitWithScope(); + + /*! \brief Helper to solve related variable's bound within conditional scope.*/ + Map GetVarBoundsFromCondition(); + + /*! \brief the condition holds on true branch. */ + const PrimExpr& condition_; + /*! \brief global domain map to updated */ + std::unordered_map* dom_map_; + /*! \brief whether is on true branch */ + bool is_true_branch_; + /*! \brief used to record and restore original var bounds */ + std::unordered_map origin_map_; +}; + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index e8865b260dc1..f3ff1f37a5da 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -31,51 +31,256 @@ #include #include "../../runtime/thread_storage_scope.h" +#include "../../support/arena.h" #include "ir_utils.h" namespace tvm { namespace tir { +using runtime::StorageRank; +using runtime::StorageScope; + bool IsDynamicSharedMemory(Var buffer_var) { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn"; } +/*! + * \brief collect the mapping from the buffer var to its allocate + */ class AllocateCollector : public StmtExprVisitor { public: void VisitStmt_(const AllocateNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - dyn_shmem_allocs_.insert(op); + dyn_shmem_allocs_[op->buffer_var.get()] = op; + } + StmtExprVisitor::VisitStmt_(op); + } + // The mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; +}; + +// Find a linear pattern of storage access +// Used for liveness analysis. +// "linear" means fitting a complex access pattern into an array of StmtEntry +// +// Define "scope" as the body of For/thread_launch/IfThenElse +// Composite scopes(loop/thread_launch/IfThen) is represented by three StmtEntry: +// before_scope -> scope_body -> after_scope +// +// This pass tries to detect last point that we need to keep memory +// alive under the same scope as Allocate. +// The storage need to be kept alive between Allocate and last access. +// The free point is only inserted at the same scope of Allocate. +// +class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { + public: + /*! \brief record the touch list of statement. */ + struct StmtEntry { + // The statement + const Object* stmt; + // The index in the linear_seq_ to point to end of the nested scope. + // This is only set to non-zero if stmt is a nested scope. + // if offset > 0, means this is the begin, the end entry is current_index + offset + // if offset < 0, means this is the end, the begin entry is current_index + offset + int64_t scope_pair_offset{0}; + // The buffer variables this statement touched. + std::vector touched; + }; + // The scope of each allocation + struct AllocEntry { + // the level in the scope stack + size_t level{0}; + // allocation stmt + const AllocateNode* alloc{nullptr}; + }; + + void VisitStmt_(const AllocateNode* op) final { + size_t level = scope_.size(); + const VarNode* buf = op->buffer_var.get(); + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const StoreNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + // Add write access. + const VarNode* buf = op->buffer_var.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + void VisitStmt_(const EvaluateNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + void VisitExpr_(const LoadNode* op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + const VarNode* buf = op->buffer_var.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::address_of())) { + const LoadNode* l = op->args[0].as(); + this->VisitExpr(l->index); + } else { + StmtExprVisitor::VisitExpr_(op); } + } + void VisitExpr_(const VarNode* buf) final { + // Directly reference to the variable count as a read. + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + template + void VisitNewScope(const T* op) { + scope_.push_back(StmtEntry()); + StmtEntry e; + e.stmt = op; + int64_t begin_index = static_cast(linear_seq_.size()); + // before scope. + linear_seq_.push_back(e); StmtExprVisitor::VisitStmt_(op); + // after scope. + e.touched = std::move(scope_.back().touched); + scope_.pop_back(); + int64_t end_index = static_cast(linear_seq_.size()); + ICHECK_GT(end_index, begin_index); + e.scope_pair_offset = begin_index - end_index; + linear_seq_.push_back(e); + // record the pointer to end index. + ICHECK_NE(end_index, 0U); + linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } + void VisitStmt_(const AttrStmtNode* op) final { + // Only record the outer most thread extent. + if (op->attr_key == attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + VisitNewScope(op); + in_thread_env_ = false; + } else if (op->attr_key == attr::extern_scope) { + VisitNewScope(op); + } else if (op->attr_key == attr::virtual_thread) { + VisitNewScope(op); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } - std::unordered_set dyn_shmem_allocs_; + // linearized access sequence. + std::vector linear_seq_; + // The storage scope of each buffer + std::unordered_map alloc_info_; + + private: + // Whether already in thread env. + bool in_thread_env_{false}; + // The scope stack. + std::vector scope_; }; +/*! + * \brief merge the buffers whose live range has no intersection and rewrite the body + */ class DynamicSharedMemoryRewriter : public StmtExprMutator { public: explicit DynamicSharedMemoryRewriter( - const std::unordered_set& dyn_shmem_allocs) + const std::unordered_map& dyn_shmem_allocs) : dyn_shmem_allocs_{dyn_shmem_allocs} {} + /*! + * \brief plan the memory reuse for all the buffer allocated in the statement + * \param stmt the statement + */ + void PlanReuse(const Stmt& stmt) { + DynSharedMemLinearAccessPatternFinder finder; + finder(stmt); + this->LivenessAnalysis(finder.linear_seq_); + this->PlanMemory(finder.linear_seq_); + } + + private: Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent && !allocated) { + if (op->attr_key == attr::thread_extent && !allocated_) { // Allocate one dynamic shared memory allocation at the beginning of thread scope - int align = 1; - for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; - align = std::max(align, alloc->dtype.bytes()); + int max_layer_num = 0; + std::vector all_entry; + for (const auto& e : const_free_map_) { + all_entry.push_back(e.second); + } + for (const StorageEntry* e : sym_free_list_) { + all_entry.push_back(e); + } + for (const StorageEntry* e : all_entry) { + max_layer_num = std::max(max_layer_num, static_cast(e->allocs.size())); + } + // calculate align for each layer of each storage entry. + std::vector align(max_layer_num, 0); + for (const StorageEntry* e : all_entry) { + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + for (const VarNode* buffer : e->allocs[i]) { + const AllocateNode* alloc = dyn_shmem_allocs_[buffer]; + align[i] = std::max(align[i], alloc->dtype.bytes()); + } + } } - for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->extents.size(), 1); - buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; - merged_alloc_size_ += alloc->extents[0] * align; + // calculate offset for each buffer based on the align of each layer + for (const StorageEntry* e : all_entry) { + PrimExpr max_inner_offset = 0; + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + PrimExpr inner_offset = 0; + for (const VarNode* buffer : e->allocs[i]) { + const AllocateNode* alloc = dyn_shmem_allocs_[buffer]; + buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset; + inner_offset += alloc->extents[0] * alloc->dtype.bytes(); + inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]); + } + max_inner_offset = max(max_inner_offset, inner_offset); + } + merged_alloc_size_ += max_inner_offset; } - allocated = true; - auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, - const_true(), StmtExprMutator::VisitStmt(op->body)); + allocated_ = true; + Allocate new_body(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(), + StmtExprMutator::VisitStmt(op->body)); return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } return StmtMutator::VisitStmt_(op); @@ -90,8 +295,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const LoadNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - auto offset = GetBufferOffset(op->buffer_var, op->dtype); - auto index = StmtExprMutator::VisitExpr(op->index); + PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype); + PrimExpr index = StmtExprMutator::VisitExpr(op->index); return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitExpr_(op); @@ -99,33 +304,265 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { Stmt VisitStmt_(const StoreNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); - auto index = StmtExprMutator::VisitExpr(op->index); - auto value = StmtExprMutator::VisitExpr(op->value); + PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype); + PrimExpr index = StmtExprMutator::VisitExpr(op->index); + PrimExpr value = StmtExprMutator::VisitExpr(op->value); return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitStmt_(op); } - private: + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + Var buffer = Downcast(op->args[1]); + if (!IsDynamicSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + return Call(op->dtype, op->op, + {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); ICHECK(it != buffer_byte_offsets_.end()); return indexdiv(it->second, dtype.bytes()); } + using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry; + struct StorageEntry { + // The constant size of the buffer in bits, only used if it is constant + uint64_t const_nbits{0}; + // Allocs that shares this entry. + // The inner vector means a "layer" + // For example, it we need to allocate C in the memory of A and B: + // | A: 4096 bytes | B: 4096 bytes | + // | C: 8192 bytes | + // Then the allocs = {{A, B}, {C}} + std::vector> allocs; + }; + + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + /*! + * \brief Liveness analysis to find gen and kill point of each variable. + * \param seq the linear pattern of storage access + */ + void LivenessAnalysis(const std::vector& seq) { + // find kill point, do a reverse linear scan. + std::unordered_set touched; + for (size_t i = seq.size(); i != 0; --i) { + const StmtEntry& s = seq[i - 1]; + for (const VarNode* buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].kill.push_back(buffer); + } + } + } + // find gen point, do forward scan + touched.clear(); + for (size_t i = 0; i < seq.size(); ++i) { + int64_t offset = seq[i].scope_pair_offset; + if (offset < 0) continue; + const StmtEntry& s = seq[i + offset]; + for (const VarNode* buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].gen.push_back(buffer); + } + } + } + } + + /*! + * \brief Memory plan algorithm + * \param seq the linear pattern of storage access + * \param alloc_info + */ + void PlanMemory(const std::vector& seq) { + std::unordered_set inplace_flag; + + for (size_t i = 0; i < seq.size(); ++i) { + auto it = event_map_.find(seq[i].stmt); + // scope_pair_offset <= 0 means it is either + // - leaf stmt(offset = 0) + // - end of scope(offset < 0) + // In both cases, we need to handle the kill event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode* var : it->second.kill) { + this->Free(var); + } + } + // scope_pair_offset >= 0 means it is either + // - leaf stmt(offset = 0) + // - beginning of scope(offset < 0) + // In both cases, we need to handle the gen event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { + for (const VarNode* var : it->second.gen) { + ICHECK(dyn_shmem_allocs_.count(var)); + const AllocateNode* alloc = dyn_shmem_allocs_[var]; + StorageEntry* dst_entry = FindAlloc(alloc); + alloc_map_[var] = dst_entry; + } + } + } + } + /*! + * \brief Allocate new storage entry. + * \param op the allocate node + * \param the size of the allocation in bits + * \return the new storage entry + */ + StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) { + ICHECK(op != nullptr); + // Re-use not successful, allocate a new buffer. + StorageEntry* entry = arena_.make(); + entry->allocs.push_back({op->buffer_var.get()}); + entry->const_nbits = const_nbits; + return entry; + } + /*! + * \brief find the storage entry in the free list for the allocate + * \param op the allocate node + * \return the storage entry + */ + StorageEntry* FindAlloc(const AllocateNode* op) { + ICHECK(op != nullptr); + // skip plan for local variable, + // compiler can do a better job with register allocation. + const uint64_t match_range = 16; + uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); + uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + if (const_nbits > 0 && const_nbits <= 32) { + return NewAlloc(op, const_nbits); + } + + if (const_nbits != 0) { + // constant allocation. + auto begin = const_free_map_.lower_bound(0); + auto mid = const_free_map_.lower_bound(const_nbits); + auto end = const_free_map_.upper_bound(const_nbits * match_range); + // Start looking at the buffer that is bigger than the required size first. + // If we find one, directly allocate the buffer in its location and remove its entry in the + // free list + for (auto it = mid; it != end; ++it) { + StorageEntry* e = it->second; + e->const_nbits = std::max(const_nbits, e->const_nbits); + const_free_map_.erase(it); + return e; + } + // Then start looking at smaller buffers. + // Keep collecting the buffer until the sum of their size exceeds the buffer to allocate + // and finally free all these entry in the free list + std::vector::iterator> delete_it; + // the alloc list for the new entry + std::vector> reuse_allocs; + uint64_t mem_ct = 0; + for (auto it = mid; it != begin;) { + --it; + delete_it.push_back(it); + mem_ct += it->second->const_nbits; + int n = it->second->allocs.size(); + if (n > static_cast(reuse_allocs.size())) { + reuse_allocs.resize(n, {}); + } + for (int i = 0; i < n; i++) { + for (const VarNode* alloc : it->second->allocs[i]) { + reuse_allocs[i].push_back(alloc); + } + } + if (mem_ct >= const_nbits) { + break; + } + } + reuse_allocs.push_back({op->buffer_var.get()}); + if (mem_ct != 0) { + StorageEntry* e = arena_.make(); + e->const_nbits = std::max(const_nbits, mem_ct); + e->allocs = reuse_allocs; + for (auto it : delete_it) { + const_free_map_.erase(it); + } + return e; + } + } else { + // if its symbolic allocation, just arbitrarily choose one entry to fit in because we don't + // know its actual size + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { + StorageEntry* e = *it; + sym_free_list_.erase(it); + return e; + } + } + return NewAlloc(op, const_nbits); + } + + /*! + * \brief add the storage entry to the buffer var into the free list. + * \param var the buffer var + */ + void Free(const VarNode* var) { + auto it = alloc_map_.find(var); + ICHECK(it != alloc_map_.end()); + StorageEntry* e = it->second; + ICHECK_NE(e->allocs.size(), 0U); + + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) return; + + // normal free. + if (e->const_nbits != 0) { + const_free_map_.insert({e->const_nbits, e}); + } else { + sym_free_list_.push_back(e); + } + } + // The var for the merged buffer Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; - std::unordered_set dyn_shmem_allocs_; + // The mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The size of the merged buffer PrimExpr merged_alloc_size_{0}; + // The mapping from the original buffer var to its offset in the merged buffer std::unordered_map buffer_byte_offsets_; - bool allocated{false}; + // The flag indicating whether the merged buffer has been allocated + bool allocated_{false}; + // Locations of free ops. + std::unordered_map event_map_; + // constant size free map. + std::multimap const_free_map_; + // symbolic free list, for non constant items. + std::list sym_free_list_; + // The allocation assign map + std::unordered_map alloc_map_; + /*! \brief allocator of all the StorageEntry*/ + support::Arena arena_; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { AllocateCollector collector; collector(stmt); if (collector.dyn_shmem_allocs_.size() > 1) { - return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_); + rewriter.PlanReuse(stmt); + return rewriter(std::move(stmt)); } return stmt; } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 795ae9d6a73a..7f2ecf54dfcb 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -278,6 +278,7 @@ class HostDeviceSplitter : public StmtMutator { WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); + device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); if (m.use_dyn_shmem_) { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1)); diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 6a26103e6079..aa586846f5d4 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -26,11 +26,14 @@ #include #include +#include "../../support/utils.h" #include "ir_utils.h" namespace tvm { namespace tir { +using support::StartsWith; + /*! * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same @@ -41,14 +44,28 @@ class ThreadBindingUnifier : public StmtExprMutator { static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); } private: - Stmt VisitStmt_(const AttrStmtNode* attr) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { // If this AttrStmt is not thread binding attribute, return as usual. - if (attr->attr_key != attr::thread_extent && attr->attr_key != attr::virtual_thread) { - return StmtMutator::VisitStmt_(attr); + if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) { + return StmtMutator::VisitStmt_(op); + } + IterVar old_iter_var = Downcast(op->node); + return UnifyThreadBindingImpl(op, old_iter_var->var, old_iter_var, old_iter_var->dom); + } + + Stmt VisitStmt_(const ForNode* op) final { + // If this For is not thread binding attribute, return as usual. + if (op->kind != ForKind::kThreadBinding) { + return StmtExprMutator::VisitStmt_(op); } + return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), + Range::FromMinExtent(op->min, op->extent)); + } - // Step 1. Fetch the old IterVar and the thread tag. - IterVar old_iter_var = Downcast(attr->node); + template + Stmt UnifyThreadBindingImpl(const Node* op, const Var& old_var, const IterVar& old_iter_var, + const Range& dom) { + // Step 1. Fetch the thread tag. IterVar new_iter_var{nullptr}; const String& thread_tag = old_iter_var->thread_tag; @@ -56,9 +73,12 @@ class ThreadBindingUnifier : public StmtExprMutator { // thread block depth is 0 before the increasement, it means we are entering a new kernel, and // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have // thread axes with different extents. - if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { + bool is_kernel_launch_scope = false; + int old_thread_block_depth = thread_block_depth_; + if (StartsWith(thread_tag, "blockIdx.") || !thread_block_depth_) { if (!thread_block_depth_) { thread_tag2iter_var_map_.clear(); + is_kernel_launch_scope = true; } ++thread_block_depth_; } @@ -69,31 +89,56 @@ class ThreadBindingUnifier : public StmtExprMutator { Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); if (it != thread_tag2iter_var_map_.end()) { new_iter_var = (*it).second; - CHECK(ana.CanProveEqual(old_iter_var->dom->extent, (*it).second->dom->extent)) + ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); + CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) << "ValueError: All loops that are bound to `" << thread_tag << "` should have the same extent. However, there are two loops with extent " - << (*it).second->dom->extent << " and " << old_iter_var->dom->extent - << ", which are not equal"; + << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal"; } else { ObjectPtr p_new_iter_var = make_object(*old_iter_var.get()); p_new_iter_var->var = Var(thread_tag); + p_new_iter_var->dom = dom; new_iter_var = IterVar(p_new_iter_var); thread_tag2iter_var_map_.Set(thread_tag, new_iter_var); + launch_threads_.push_back(new_iter_var); } // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the // new variable in further mutation. Thus, we store the mapping entry. - var_substitution_map_.Set(old_iter_var->var, new_iter_var->var); - - // Step 5. Mutate recursively, update the AttrStmt with the new IterVar, and decrease the depth - // counter if the thread tag starts with "blockIdx". - AttrStmt new_attr = Downcast(StmtMutator::VisitStmt_(attr)); - ObjectPtr p_new_attr = CopyOnWrite(new_attr.get()); - p_new_attr->node = new_iter_var; - if (std::string(thread_tag).substr(0, 9) == "blockIdx.") { - --thread_block_depth_; + var_substitution_map_.Set(old_var, new_iter_var->var); + + // Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth + // counter. Emit for-loops to launch threads if current statement is the outermost thread + // binding of the kernel. + Stmt new_stmt = StmtMutator::VisitStmt_(op); + auto* new_node = new_stmt.as(); + ICHECK(new_node); + thread_block_depth_ = old_thread_block_depth; + if (is_kernel_launch_scope) { + return EmitLaunchThreads(new_node->body); + } else { + return new_node->body; } - return Stmt(p_new_attr); + } + + /*! + * \brief Emit loop nests representing all thread bindings of the kernel + * \param body The body of the innermost loop of the thread bindings. + * \return The loop nests of the thread bindings. + */ + Stmt EmitLaunchThreads(const Stmt& body) { + Stmt result = body; + while (!launch_threads_.empty()) { + const IterVar& thread_binding = launch_threads_.back(); + // Recreate the IterVar as we don't duplicate `dom` in both For and IterVar. This is + // necessary for unit tests. + result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, + ForKind::kThreadBinding, result, + IterVar(NullValue(), Var(""), IterVarType::kThreadIndex, + thread_binding->thread_tag)); + launch_threads_.pop_back(); + } + return result; } PrimExpr VisitExpr_(const VarNode* var) final { @@ -106,8 +151,13 @@ class ThreadBindingUnifier : public StmtExprMutator { /*! * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all * occurrences of the thread tag - * */ + */ Map thread_tag2iter_var_map_; + /*! + * \brief A list of IterVar corresponding to threads in current kernel. This will be used to + * generate for-loops to launch threads. + */ + Array launch_threads_; /*! \brief A mapping from old variables to new variables, which is used for substitution */ Map var_substitution_map_; /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ diff --git a/tests/cpp/relay/backend/executor_test.cc b/tests/cpp/relay/backend/executor_test.cc new file mode 100644 index 000000000000..3367390b27f2 --- /dev/null +++ b/tests/cpp/relay/backend/executor_test.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +namespace tvm { +namespace relay { + +TVM_REGISTER_EXECUTOR("TestExecutor") + .add_attr_option("my_bool") + .add_attr_option>("your_names") + .add_attr_option("another_option") + .add_attr_option("defaulty_the_default_option", Bool(false)); + +TEST(Executor, Create) { + Map attrs = {{"my_bool", Bool(true)}}; + Executor my_exec = Executor::Create("TestExecutor", attrs); + ASSERT_EQ(my_exec->GetAttr("my_bool"), true); + ASSERT_EQ(my_exec->GetAttr>("your_names").defined(), false); + ASSERT_EQ(my_exec->GetAttr("defaulty_the_default_option"), false); +} + +TEST(Executor, UnknownAttr) { + Map attrs = {{"woofles", Bool(true)}}; + ASSERT_THROW(Executor::Create("TestExecutor", attrs), Error); +} + +TEST(Executor, IncorrectAttrType) { + Map attrs = {{"my_bool", String("snuck_in")}}; + ASSERT_THROW(Executor::Create("TestExecutor", attrs), Error); +} + +TEST(Executor, UnregisteredName) { + Map attrs = {}; + ASSERT_THROW(Executor::Create("NeverNameAnExecutorThis", attrs), Error); +} + +TEST(ExecutorRegistry, ListExecutors) { + Array names = Executor::ListExecutors(); + ICHECK_EQ(names.empty(), false); + ICHECK_EQ(std::count(std::begin(names), std::end(names), "TestExecutor"), 1); +} + +TEST(ExecutorRegistry, ListExecutorOptions) { + Map attrs = Executor::ListExecutorOptions("TestExecutor"); + + ICHECK_EQ(attrs.empty(), false); + ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["your_names"], "Array"); + ICHECK_EQ(attrs["another_option"], "runtime.String"); +} + +TEST(ExecutorRegistry, ListExecutorOptionsNoExecutor) { + ASSERT_THROW(Executor::ListExecutorOptions("NeverNameAnExecutorThis"), Error); +} + +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc new file mode 100644 index 000000000000..53ea7e39ed59 --- /dev/null +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +namespace tvm { +namespace relay { + +TVM_REGISTER_RUNTIME("TestRuntime") + .add_attr_option("my_bool") + .add_attr_option>("your_names") + .add_attr_option("another_option") + .add_attr_option("defaulty_the_default_option", Bool(false)); + +TEST(Runtime, Create) { + Map attrs = {{"my_bool", Bool(true)}}; + Runtime my_runtime = Runtime::Create("TestRuntime", attrs); + ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); + ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); + ASSERT_EQ(my_runtime->GetAttr("defaulty_the_default_option"), false); +} + +TEST(Runtime, UnknownAttr) { + Map attrs = {{"woofles", Bool(true)}}; + ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); +} + +TEST(Runtime, IncorrectAttrType) { + Map attrs = {{"my_bool", String("snuck_in")}}; + ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); +} + +TEST(Runtime, UnregisteredName) { + Map attrs = {}; + ASSERT_THROW(Runtime::Create("NeverNameAnRuntimeThis", attrs), Error); +} + +TEST(RuntimeRegistry, ListRuntimes) { + Array names = Runtime::ListRuntimes(); + ICHECK_EQ(names.empty(), false); + ICHECK_EQ(std::count(std::begin(names), std::end(names), "TestRuntime"), 1); +} + +TEST(RuntimeRegistry, ListRuntimeOptions) { + Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); + + ICHECK_EQ(attrs.empty(), false); + ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["your_names"], "Array"); + ICHECK_EQ(attrs["another_option"], "runtime.String"); +} + +TEST(RuntimeRegistry, ListRuntimeOptionsNoRuntime) { + ASSERT_THROW(Runtime::ListRuntimeOptions("NeverNameAnRuntimeThis"), Error); +} + +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc index 37b44524e770..ae95185cb287 100644 --- a/tests/cpp/relay_dismantler_test.cc +++ b/tests/cpp/relay_dismantler_test.cc @@ -143,3 +143,22 @@ TEST(Relay, TupleiGetItemSharedTuple) { .as() ->args.size()); } + +TEST(Relay, OutOfStackLet) { + auto foo = [] { + auto add_op = relay::Op::Get("add"); + auto p = relay::Var("p", relay::TensorType({3, 2}, DataType::Float(32))); + int size = 1e6 - 1; + std::vector vars; + for (int i = 0; i < size; ++i) { + vars.emplace_back("x_" + std::to_string(i), relay::TensorType({3, 2}, DataType::Float(32))); + } + Expr body = vars[size - 1]; + for (int i = size - 1; i >= 0; --i) { + Var v = i == 0 ? p : vars[i - 1]; + body = relay::Let(vars[i], relay::Call(add_op, {v, v}), body); + } + relay::Function func = relay::Function({p}, body, relay::Type(), {}); + }; + ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); +} diff --git a/tests/cpp/support_test.cc b/tests/cpp/support_test.cc index df9271f4b49c..01111d910246 100644 --- a/tests/cpp/support_test.cc +++ b/tests/cpp/support_test.cc @@ -56,5 +56,11 @@ TEST(HashTests, HashStability) { EXPECT_EQ(::tvm::support::HashCombine(e, f), 2722928432); } +TEST(StartsWithTests, Basic) { + EXPECT_TRUE(::tvm::support::StartsWith("abc", "abc")); + EXPECT_TRUE(::tvm::support::StartsWith("abcd", "abc")); + EXPECT_FALSE(::tvm::support::StartsWith("abc", "abcd")); +} + } // namespace test } // namespace tvm diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc new file mode 100644 index 000000000000..ae5f5d0c3dc4 --- /dev/null +++ b/tests/cpp/target/compilation_config_test.cc @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +Target TestCpuTarget() { return Target("llvm -mcpu arm64"); } + +Target TestCudaTarget() { return Target("nvidia/tesla-p40"); } + +Target TestDefaultCpuTarget() { return Target("llvm"); } + +Target TestExtDevTarget() { return Target("ext_dev"); } + +CompilationConfig TestCompilationConfig() { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLCPU)), TestCpuTarget()); + return CompilationConfig(pass_ctx, legacy_target_map, TestDefaultCpuTarget()); +} + +TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 1); + EXPECT_TRUE(StructuralEqual()((*config->legacy_target_map.begin()).second, + Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 1); + EXPECT_TRUE( + StructuralEqual()(config->primitive_targets[0], Target::WithHost(cuda_target, host_target))); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cuda_target, host_target))); +} + +TEST(CompilationConfig, Constructor_Homegenoous_InnerHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestCpuTarget(); + Target cuda_target = Target::WithHost(TestCudaTarget(), host_target); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); +} + +TEST(CompilationConfig, Constructor_Homogenous_CPUHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, cpu_target)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cpu_target, cpu_target))); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); + Target host_target = TestCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, host_target); + + SEScope expected_default_primitive_se_scope(kDLCUDA, 0, + Target::WithHost(cuda_target, host_target)); + SEScope expected_host_se_scope(kDLCPU, 0, host_target); + + ASSERT_EQ(config->legacy_target_map.size(), 2); + EXPECT_TRUE(config->host_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); + ASSERT_EQ(config->primitive_targets.size(), 2); + EXPECT_TRUE( + StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_FALSE(config->optional_homogeneous_target.defined()); +} + +TEST(CompilationConfig, Constructor_InvalidAttribute) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kInvalidDeviceType))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); +} + +TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLMetal))); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); +} + +TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); + legacy_target_map.Set(Integer(static_cast(kDLExtDev)), TestExtDevTarget()); + EXPECT_ANY_THROW( + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); +} + +TEST(CompilationConfig, CanonicalSEScope) { + Target host_target = TestDefaultCpuTarget(); + Target cuda_target = TestCudaTarget(); + Target cpu_target = TestCpuTarget(); + CompilationConfig config = TestCompilationConfig(); + + { + SEScope in = SEScope(kDLCPU); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cpu_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } + { + SEScope in = SEScope(kDLCUDA); + SEScope actual = config->CanonicalSEScope(in); + ASSERT_TRUE(actual->target.defined()); + EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); + EXPECT_EQ(config->CanonicalSEScope(in), actual); + } +} + +TEST(CompilationConfig, CanonicalSEScope_NoDevice) { + CompilationConfig config = TestCompilationConfig(); + SEScope fully_unconstrained; + EXPECT_ANY_THROW(config->CanonicalSEScope(fully_unconstrained)); + SEScope missing_device(kInvalidDeviceType, 3, {}, "local"); + EXPECT_ANY_THROW(config->CanonicalSEScope(missing_device)); +} + +TEST(CompilationConfig, CanonicalSEScope_NoMatchingTarget) { + CompilationConfig config = TestCompilationConfig(); + SEScope no_such_target(kDLMetal); + EXPECT_ANY_THROW(config->CanonicalSEScope(no_such_target)); +} + +} // namespace +} // namespace tvm diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc new file mode 100644 index 000000000000..166ba46faf37 --- /dev/null +++ b/tests/cpp/target/se_scope_test.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +TEST(SEScope, Join_Defined) { + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, -1, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, target_a, "global"); + SEScope rhs = SEScope(kDLCUDA, 3); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCUDA, 2, target_a); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = SEScope(kDLCUDA, 2, target_a); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + SEScope expected = rhs; + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } +} + +TEST(SEScope, Join_Undefined) { + { + SEScope lhs = SEScope(kDLCUDA); + SEScope rhs = SEScope(kDLCPU); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3); + SEScope rhs = SEScope(kDLCUDA, 4); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda")); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda")); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda"), "local"); + SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"), "global"); + Optional actual = SEScope::Join(lhs, rhs); + EXPECT_FALSE(actual); + } +} + +TEST(SEScope, Default) { + Target target_a = Target("cuda"); + SEScope lhs = SEScope(kDLCUDA, -1, Target(), "global"); + SEScope rhs = SEScope(kDLCUDA, 3, target_a, "local"); + SEScope actual = SEScope::Default(lhs, rhs); + SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual, expected)); +} + +TEST(SEScope, Constructor_Invalid) { EXPECT_ANY_THROW(SEScope(kDLCPU, -1, Target("cuda"))); } + +TEST(SEScopeCache, Memoized) { + SEScopeCache cache; + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + SEScope se_scope_a = cache.Make(kDLCUDA, 3, target_a, "local"); + SEScope se_scope_b = cache.Make(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), se_scope_a); + EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), se_scope_b); + EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), se_scope_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), se_scope_a); +} + +} // namespace +} // namespace tvm diff --git a/tests/lint/pylint.sh b/tests/lint/pylint.sh index a96c2672c01f..e47d576ced2f 100755 --- a/tests/lint/pylint.sh +++ b/tests/lint/pylint.sh @@ -19,3 +19,4 @@ python3 -m pylint python/tvm --rcfile=$(dirname "$0")/pylintrc python3 -m pylint vta/python/vta --rcfile=$(dirname "$0")/pylintrc +python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile=$(dirname "$0")/pylintrc diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index 73361774821b..8625b4a45364 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -24,16 +24,10 @@ from tvm.micro import project from tvm import micro, relay -TEMPLATE_PROJECT_DIR = ( - pathlib.Path(__file__).parent - / ".." - / ".." - / ".." - / "apps" - / "microtvm" - / "arduino" - / "template_project" -).resolve() +TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino")) + + +BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" diff --git a/tests/micro/zephyr/test_utils.py b/tests/micro/zephyr/test_utils.py index c27c869509d7..e4a22d2be647 100644 --- a/tests/micro/zephyr/test_utils.py +++ b/tests/micro/zephyr/test_utils.py @@ -18,8 +18,9 @@ import os import json import pathlib -import logging import tarfile +import tempfile +from typing import Union import numpy as np @@ -29,18 +30,10 @@ import requests import tvm.micro +from tvm.micro import export_model_library_format +from tvm.micro.testing import mlf_extract_workspace_size_bytes - -TEMPLATE_PROJECT_DIR = ( - pathlib.Path(__file__).parent - / ".." - / ".." - / ".." - / "apps" - / "microtvm" - / "zephyr" - / "template_project" -).resolve() +TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr")) BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" @@ -77,19 +70,29 @@ def has_fpu(board: str): def build_project(temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None): project_dir = temp_dir / "project" - project = tvm.micro.generate_project( - str(TEMPLATE_PROJECT_DIR), - mod, - project_dir, - { - "extra_files_tar": extra_files_tar, - "project_type": "aot_demo", - "west_cmd": west_cmd, - "verbose": bool(build_config.get("debug")), - "zephyr_board": zephyr_board, - }, - ) - project.build() + + with tempfile.TemporaryDirectory() as tar_temp_dir: + model_tar_path = pathlib.Path(tar_temp_dir) / "model.tar" + export_model_library_format(mod, model_tar_path) + + workspace_size = mlf_extract_workspace_size_bytes(model_tar_path) + project = tvm.micro.project.generate_project_from_mlf( + str(TEMPLATE_PROJECT_DIR), + project_dir, + model_tar_path, + { + "extra_files_tar": extra_files_tar, + "project_type": "aot_demo", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + "compile_definitions": [ + # TODO(mehrdadh): It fails without offset. + f"-DWORKSPACE_SIZE={workspace_size + 128}", + ], + }, + ) + project.build() return project, project_dir @@ -129,31 +132,6 @@ def create_header_file(tensor_name, npy_data, output_path, tar_file): tar_file.addfile(ti, io.BytesIO(header_file_bytes)) -def _read_line(fd, timeout_sec: int): - data = "" - new_line = False - while True: - if new_line: - break - new_data = fd.read(1, timeout_sec=timeout_sec) - logging.debug(f"read data: {new_data}") - for item in new_data: - new_c = chr(item) - data = data + new_c - if new_c == "\n": - new_line = True - break - return data - - -def get_message(fd, expr: str, timeout_sec: int): - while True: - data = _read_line(fd, timeout_sec) - logging.debug(f"new line: {data}") - if expr in data: - return data - - # TODO move CMSIS integration to microtvm_api_server.py # see https://discuss.tvm.apache.org/t/tvm-capturing-dependent-libraries-of-code-generated-tir-initially-for-use-in-model-library-format/11080 def loadCMSIS(temp_dir): diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index be1f231156ad..10759c3790db 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -374,6 +374,9 @@ def test_tensors(sess): @tvm.testing.requires_micro def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): """Test AutoTune for microTVM Zephyr""" + if board != "qemu_x86": + pytest.xfail(f"Autotune fails on {board}.") + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index 5bc665b748f6..7cd32f4e1879 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -33,6 +33,7 @@ from tvm.contrib.download import download_testdata from tvm.micro.model_library_format import generate_c_interface_header +from tvm.micro.testing import aot_transport_init_wait, aot_transport_find_message import test_utils @@ -40,23 +41,13 @@ @tvm.testing.requires_micro def test_tflite(temp_dir, board, west_cmd, tvm_debug): """Testing a TFLite model.""" - - if board not in [ - "qemu_x86", - "mps2_an521", - "nrf5340dk_nrf5340_cpuapp", - "nucleo_l4r5zi", - "qemu_cortex_r5", - ]: - pytest.skip(msg="Model does not fit.") - model = test_utils.ZEPHYR_BOARDS[board] - input_shape = (1, 32, 32, 3) - output_shape = (1, 10) + input_shape = (1, 49, 10, 1) + output_shape = (1, 12) build_config = {"debug": tvm_debug} - model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite" - model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model") + model_url = "https://github.com/tlc-pack/web-data/raw/25fe99fb00329a26bd37d3dca723da94316fd34c/testdata/microTVM/model/keyword_spotting_quant.tflite" + model_path = download_testdata(model_url, "keyword_spotting_quant.tflite", module="model") # Import TFLite model tflite_model_buf = open(model_path, "rb").read() @@ -71,20 +62,25 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): # Load TFLite model and convert to Relay relay_mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "float32"} + tflite_model, shape_dict={"input_1": input_shape}, dtype_dict={"input_1 ": "int8"} ) target = tvm.target.target.micro( - model, options=["-link-params=1", "--executor=aot", "--unpacked-api=1", "--interface-api=c"] + model, + options=[ + "-link-params=1", + "--executor=aot", + "--unpacked-api=1", + "--interface-api=c", + "--workspace-byte-alignment=4", + ], ) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): lowered = relay.build(relay_mod, target, params=params) # Load sample and generate input/output header files - sample_url = "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/testdata_image_classification_fp32_8.npy" - sample_path = download_testdata( - sample_url, "testdata_image_classification_fp32_8.npy", module="data" - ) + sample_url = "https://github.com/tlc-pack/web-data/raw/967fc387dadb272c5a7f8c3461d34c060100dbf1/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy" + sample_path = download_testdata(sample_url, "keyword_spotting_int8_6.pyc.npy", module="data") sample = np.load(sample_path) with tempfile.NamedTemporaryFile() as tar_temp_file: @@ -99,7 +95,7 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): test_utils.create_header_file("input_data", sample, "include", tf) test_utils.create_header_file( - "output_data", np.zeros(shape=output_shape, dtype="float32"), "include", tf + "output_data", np.zeros(shape=output_shape, dtype="int8"), "include", tf ) project, _ = test_utils.build_project( @@ -113,17 +109,16 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): project.flash() with project.transport() as transport: - timeout_read = 60 - test_utils.get_message(transport, "#wakeup", timeout_sec=timeout_read) - transport.write(b"start\n", timeout_sec=5) - result_line = test_utils.get_message(transport, "#result", timeout_sec=timeout_read) + aot_transport_init_wait(transport) + transport.write(b"infer%", timeout_sec=5) + result_line = aot_transport_find_message(transport, "result", timeout_sec=60) result_line = result_line.strip("\n") result_line = result_line.split(":") result = int(result_line[1]) time = int(result_line[2]) logging.info(f"Result: {result}\ttime: {time} ms") - assert result == 8 + assert result == 6 @tvm.testing.requires_micro diff --git a/tests/micro/zephyr/test_zephyr_armv7m.py b/tests/micro/zephyr/test_zephyr_armv7m.py index 972ffe2bda35..2366bad203be 100644 --- a/tests/micro/zephyr/test_zephyr_armv7m.py +++ b/tests/micro/zephyr/test_zephyr_armv7m.py @@ -25,8 +25,6 @@ import pytest import numpy as np -import test_utils - import tvm import tvm.rpc import tvm.micro @@ -35,18 +33,17 @@ from tvm.contrib.download import download_testdata from tvm.micro.model_library_format import generate_c_interface_header +from tvm.micro.testing import aot_transport_init_wait, aot_transport_find_message -import conftest - +import test_utils _LOG = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) def _open_tflite_model(): # Import TFLite model - model_url = "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/model/mnist_model_quant.tflite" + model_url = "https://github.com/tlc-pack/web-data/raw/b2f3c02427b67267a00fd968ba1fce28fc833028/testdata/microTVM/model/mnist_model_quant.tflite" model_path = download_testdata(model_url, "mnist_model_quant.tflite", module="model") tflite_model_buf = open(model_path, "rb").read() @@ -145,15 +142,15 @@ def _run_model(temp_dir, board, west_cmd, lowered, build_config, sample, output_ project.flash() with project.transport() as transport: - timeout_read = 60 - transport.write(b"start\n", timeout_sec=5) - result_line = test_utils.get_message(transport, "#result", timeout_sec=timeout_read) + aot_transport_init_wait(transport) + transport.write(b"infer%", timeout_sec=5) + result_line = aot_transport_find_message(transport, "result", timeout_sec=60) result_line = result_line.strip("\n") result_line = result_line.split(":") result = int(result_line[1]) time = int(result_line[2]) - logging.info(f"Result: {result}\ttime: {time} ms") + _LOG.info(f"Result: {result}\ttime: {time} ms") return result, time @@ -186,6 +183,17 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug): relay_mod_no_simd = _apply_desired_layout_no_simd(relay_mod) target = tvm.target.target.micro( + model, + options=[ + "-keys=cpu", + "-link-params=1", + "--executor=aot", + "--unpacked-api=1", + "--interface-api=c", + ], + ) + + target_simd = tvm.target.target.micro( model, options=[ "-keys=arm_cpu,cpu", @@ -203,7 +211,7 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug): os.makedirs(temp_dir_no_simd, exist_ok=True) with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lowered_simd = relay.build(relay_mod_simd, target, params=params) + lowered_simd = relay.build(relay_mod_simd, target_simd, params=params) lowered_no_simd = relay.build(relay_mod_no_simd, target, params=params) result_simd, time_simd = _run_model( temp_dir_simd, board, west_cmd, lowered_simd, build_config, sample, output_shape diff --git a/tests/python/conftest.py b/tests/python/conftest.py index e8042c8f5095..ab3ea4e4ec06 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -37,6 +37,3 @@ # collect_ignore.append("unittest/test_auto_scheduler_measure.py") # exception ignored collect_ignore.append("unittest/test_tir_intrin.py") - -if tvm.support.libinfo().get("USE_MICRO", "OFF") != "ON": - collect_ignore.append("unittest/test_micro_transport.py") diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index f151a85ec5b1..e582874d1de2 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -184,7 +184,7 @@ def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, acl_parti ), "Got {} Arm Compute Library partitions, expected {}".format( partition_count, acl_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, params=params) diff --git a/tests/python/contrib/test_bnns/infrastructure.py b/tests/python/contrib/test_bnns/infrastructure.py index 46bd049402a9..5a12b0487408 100644 --- a/tests/python/contrib/test_bnns/infrastructure.py +++ b/tests/python/contrib/test_bnns/infrastructure.py @@ -142,7 +142,7 @@ def build_module(mod, target, params=None, enable_bnns=True, tvm_ops=0): with tvm.transform.PassContext(opt_level=3): if enable_bnns: mod = partition_for_bnns(mod) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target=target, target_host=target, params=params) diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index d785cfa199ae..42eb31a3532c 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -103,7 +103,7 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z assert any(attrs), "At least one function with external attributes was expected." compilers = [ - key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() ] assert any(compilers), "Module does not contain function for cmsisnn target." diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index b030437252dc..40e12fc962b2 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -85,7 +85,7 @@ def test_softmax_int8(zero_point, scale): assert any(attrs), "At least one function with external attributes was expected." compilers = [ - key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() ] assert any(compilers), "Module does not contain function for cmsisnn target." diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py new file mode 100644 index 000000000000..5a1ff8b2c17d --- /dev/null +++ b/tests/python/contrib/test_cutlass.py @@ -0,0 +1,293 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +import math +import pytest +import tvm +from tvm import relay +import numpy as np +from tvm.runtime.vm import VirtualMachine +from tvm.relay.op.contrib.cutlass import partition_for_cutlass +from tvm.contrib.cutlass import ( + tune_cutlass_kernels, + build_cutlass_kernels, + build_cutlass_kernels_vm, +) + +logging.basicConfig(level=logging.INFO) + + +def has_cublas(): + return tvm.get_global_func("tvm.contrib.cublas.matmul", True) != None + + +def has_cutlass(): + return tvm.get_global_func("relay.ext.cutlass", True) != None + + +def get_ref_rt_mod(mod, params, target="cuda"): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) + dev = tvm.device(target, 0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + return rt_mod, dev + + +def get_ref_vm(mod, params, target="cuda"): + with tvm.transform.PassContext(opt_level=3): + vm_exec = relay.vm.compile(mod, target=target, params=params) + code, lib = vm_exec.save() + dev = tvm.device(target, 0) + vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib) + return VirtualMachine(vm_exec, dev), dev + + +def get_output(rt_mod, names, inputs): + for name, inp in zip(names, inputs): + rt_mod.set_input(name, inp) + rt_mod.run() + return rt_mod.get_output(0).asnumpy() + + +def get_output_vm(vm, names, inputs): + params = dict(zip(names, inputs)) + return vm.invoke("main", **params).numpy() + + +def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"): + data = relay.var("data", shape=data_shape, dtype="float16") + weight = relay.var("weight", shape=weight_shape, dtype="float16") + return relay.nn.dense(data, weight, out_dtype=out_dtype) + + +def get_dense(M, N, K, out_dtype="float16"): + return get_dense_with_shape((M, K), (N, K), out_dtype) + + +def get_dense_bias(M, N, K, out_dtype="float16"): + dense = get_dense(M, N, K, out_dtype=out_dtype) + bias = relay.var("bias", shape=(N,), dtype=out_dtype) + return relay.nn.bias_add(dense, bias) + + +def get_dense_bias_relu(M, N, K, out_dtype="float16"): + return relay.nn.relu(get_dense_bias(M, N, K, out_dtype="float16")) + + +def get_dense_bias_gelu(M, N, K, out_dtype="float16"): + bias_add = get_dense_bias(M, N, K, out_dtype) + mul = bias_add * relay.const((1.0 / math.sqrt(2.0)), dtype=out_dtype) + if out_dtype == "float16": + erf = relay.cast(relay.op.erf(relay.cast(mul, "float32")), "float16") + else: + erf = relay.op.erf(mul) + mul_half = erf * relay.const(0.5, dtype=out_dtype) + add = mul_half + relay.const(0.5, dtype=out_dtype) + return add * bias_add + + +def get_batch_matmul_with_shape(x_shape, y_shape, out_dtype="float16"): + x = relay.var("x", shape=x_shape, dtype="float16") + y = relay.var("y", shape=y_shape, dtype="float16") + return relay.nn.batch_matmul(x, y, out_dtype=out_dtype) + + +def get_batch_matmul(batch, M, N, K, out_dtype="float16"): + return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16") + + +def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): + mod = partition_for_cutlass(mod) + mod, num_cutlass_partition = tune_cutlass_kernels( + mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir + ) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target="cuda", params=params) + lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path) + dev = tvm.device("cuda", 0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + return rt_mod, dev, num_cutlass_partition + + +def profile_and_build_vm( + mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro" +): + mod = partition_for_cutlass(mod) + mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, tmp_dir=tmp_dir) + with tvm.transform.PassContext(opt_level=3): + vm_exec = relay.vm.compile(mod, target="cuda", params=params) + vm_exec = build_cutlass_kernels_vm(vm_exec, sm, tmp_dir, lib_path, vmcode_path) + dev = tvm.device("cuda", 0) + return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition + + +def verify_dense( + func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False +): + if not has_cutlass(): + return + mod = tvm.IRModule.from_expr(func) + typ = relay.transform.InferType()(mod)["main"].body.checked_type + out_dtype = typ.dtype + use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) + np_data = np.random.uniform(-1, 1, (M, K)).astype("float16") + np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16") + np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype) + + params = {"weight": np_weight, "bias": np_bias} + + if use_vm: + if ref_target == "cuda" and out_dtype == "float16": + # Uncomment "return" below to see the accuracy difference of static vs dynamic TVM native fp16 dense + # The static one can use a tensorcore schedule, but the dynamic one cannot + rt_mod, dev = get_ref_vm(tvm.IRModule.from_expr(get_dense(M, N, K)), params) + num_partition = 1 + logging.warning( + "The reference fp16 dense with dynamic shape using fp16 accumulation has accuracy issues." + ) + return + else: + rt_mod, dev, num_partition = profile_and_build_vm(mod, params, sm) + + rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target) + x = tvm.nd.array(np_data, device=dev) + out = get_output_vm(rt_mod, ["data"], [x]) + ref_out = get_output_vm(rt_mod_ref, ["data"], [x]) + else: + rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target) + rt_mod, dev, num_partition = profile_and_build(mod, params, sm) + x = tvm.nd.array(np_data, device=dev) + out = get_output(rt_mod, ["data"], [x]) + ref_out = get_output(rt_mod_ref, ["data"], [x]) + + assert num_partition > 0 + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + + if run_benchmark: + print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) + print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600)) + + +def verify_batch_matmul( + func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False +): + if not has_cutlass(): + return + mod = tvm.IRModule.from_expr(func) + typ = relay.transform.InferType()(mod)["main"].body.checked_type + use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) + x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16") + y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16") + + if use_vm: + rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm) + rt_mod_ref, dev = get_ref_vm(mod, {}, target=ref_target) + assert num_partition > 0 + x = tvm.nd.array(x_np, device=dev) + y = tvm.nd.array(y_np, device=dev) + out = get_output_vm(rt_mod, ["x", "y"], [x, y]) + ref_out = get_output_vm(rt_mod_ref, ["x", "y"], [x, y]) + else: + rt_mod, dev, num_partition = profile_and_build(mod, {}, sm) + rt_mod_ref, dev = get_ref_rt_mod(mod, {}) + assert num_partition > 0 + + x = tvm.nd.array(x_np, device=dev) + y = tvm.nd.array(y_np, device=dev) + out = get_output(rt_mod, ["x", "y"], [x, y]) + ref_out = get_output(rt_mod_ref, ["x", "y"], [x, y]) + + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + + if True: + print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) + print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) + + +M = 1820 +N = 768 +K = 768 + + +def test_dense(): + verify_dense(get_dense(M, N, K), M, N, K) + verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias(): + verify_dense(get_dense_bias(M, N, K), M, N, K) + verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias_relu(): + verify_dense(get_dense_bias_relu(M, N, K), M, N, K) + verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) + + +def test_dense_bias_gelu(): + verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) + verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) + + +def test_dense_dynamic(): + data_shape = (relay.Any(), K) + weight_shape = (relay.Any(), K) + + if has_cublas(): + # TVM native fp16 dense (without tensorcore), using fp16 accum, seems to have accuracy issues + # Use cublas as a reference + verify_dense( + get_dense_with_shape(data_shape, weight_shape), + M, + N, + K, + ref_target="cuda -libs=cublas", + ) + + verify_dense( + get_dense_with_shape(data_shape, weight_shape, out_dtype="float32"), + M, + N, + K, + atol=1e-4, + rtol=1e-4, + ) + + +def test_batch_matmul(): + batch = 8 + verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K) + verify_batch_matmul(get_batch_matmul(batch, M, N, K, out_dtype="float32"), batch, M, N, K) + + if has_cublas(): + # Test dynamic shape batch_matmul + # AutoTVM does not seem to support it + x_shape = (relay.Any(), relay.Any(), K) + y_shape = (relay.Any(), relay.Any(), K) + + verify_batch_matmul( + get_batch_matmul_with_shape(x_shape, y_shape), + batch, + M, + N, + K, + ref_target="cuda -libs=cublas", + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosn/infrastructure.py b/tests/python/contrib/test_ethosn/infrastructure.py index 92e8f11a2312..f16c37fe19af 100644 --- a/tests/python/contrib/test_ethosn/infrastructure.py +++ b/tests/python/contrib/test_ethosn/infrastructure.py @@ -149,7 +149,7 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): npu_partitions : int, optional The number of Ethos-N partitions expected. """ - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() with tvm.transform.PassContext( opt_level=3, config={"relay.ext.ethos-n.options": {"variant": get_ethosn_variant()}} ): @@ -170,11 +170,19 @@ def build(mod, params, npu=True, expected_host_ops=0, npu_partitions=1): assert ( host_op_count == expected_host_ops ), "Got {} host operators, expected {}".format(host_op_count, expected_host_ops) - partition_count = 0 - for global_var in mod.get_global_vars(): - if "ethos-n" in global_var.name_hint: - partition_count += 1 + attrs = [ + mod[var.name_hint].attrs + for var in mod.get_global_vars() + if mod[var.name_hint].attrs + ] + partition_count = sum( + [ + key == "Compiler" and value == "ethos-n" + for attr in attrs + for key, value in attr.items() + ] + ) assert ( npu_partitions == partition_count ), "Got {} ethos-n partitions, expected {}".format(partition_count, npu_partitions) @@ -254,7 +262,9 @@ def inference_result(outputs): def test_error(mod, params, err_msg): caught = None - with tvm.transform.PassContext(opt_level=3): + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.ethos-n.options": {"variant": get_ethosn_variant()}} + ): with tvm.target.Target("llvm"): try: mod = relay.transform.InferType()(mod) @@ -262,7 +272,7 @@ def test_error(mod, params, err_msg): except tvm.error.TVMError as e: caught = e.args[0] finally: - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() assert caught is not None assert err_msg in caught, caught @@ -324,7 +334,4 @@ def get_ethosn_api_version(): def get_ethosn_variant(): - ethosn_variant_config = os.getenv("ETHOSN_VARIANT_CONFIG") - if ethosn_variant_config is not None: - return "Ethos-N78_1TOPS_2PLE_RATIO" - return "Ethos-N77" + return os.getenv("ETHOSN_VARIANT_CONFIG", default="Ethos-N77") diff --git a/tests/python/contrib/test_ethosn/test_networks.py b/tests/python/contrib/test_ethosn/test_networks.py index f720c55c567a..3a8b95496fde 100644 --- a/tests/python/contrib/test_ethosn/test_networks.py +++ b/tests/python/contrib/test_ethosn/test_networks.py @@ -123,9 +123,9 @@ def test_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"1fd4ef29a1ea9f3a015cab87c0b8014a"} + _compile_hash = {"0433d3c3947a067b36f0228bdb5f1838"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"b879dfbff1f907eaf6129dfd41b44ece"} + _compile_hash = {"e4ed29dceb1187505948ab17fc3cc6d6"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"9c9f63b30824f5b223cdb27d2f22c857"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -150,9 +150,9 @@ def test_inception_v3(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"b90ed315639c6a0e97584c2dbc42a55c"} + _compile_hash = {"43dc2097127eb224c0191b1a15f8acca"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"5693569055695e581a8739194d0301aa"} + _compile_hash = {"7db23387bdc5af6eaa1ae3f7d456caf0"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"46ccafc840633633aca441645e41b444"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -176,9 +176,9 @@ def test_inception_v4(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"b36877d2386d9f9c37a11772e3c4072c"} + _compile_hash = {"fab6c2297502f95d33079c6ce1a737f9"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"b5046a6f56d78af0b4f51960bf2deeda"} + _compile_hash = {"8da68849b75613ac3dffd3fff2dd87da"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"4a1a56393078367dd27915a188d6a6af"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": @@ -202,9 +202,9 @@ def test_ssd_mobilenet_v1(): # codegen, which could come about from either a change in Support Library # version or a change in the Ethos-N codegen. To update this requires running # on hardware that isn't available in CI. - _compile_hash = {"956caf9e7fe5cfd5c042bd17857f7407", "4313033d14328e2aa022b1bd71b27b1c"} + _compile_hash = {"2345cf5d6c0013bad7c76dcccee9d862", "7795b6c67178da9d1f9b98063bad75b1"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": - _compile_hash = {"dc60cc687d892cd2877873094e9dfc0b", "6b3deeec16c24c0dcef23df0db5fb162"} + _compile_hash = {"928dc6ae5ce49a4ad63ca87f7575970f", "b092f9820f7e9341fc53daa781b98772"} if tei.get_ethosn_api_version() == 2011: _compile_hash = {"10826406ae724e52f360a06c35ced09d", "9a484d5ecec7acb18c9d6bc6058be031"} if tei.get_ethosn_variant() == "Ethos-N78_1TOPS_2PLE_RATIO": diff --git a/tests/python/contrib/test_ethosn/test_partition_params.py b/tests/python/contrib/test_ethosn/test_partition_params.py new file mode 100644 index 000000000000..da1750a7e4cb --- /dev/null +++ b/tests/python/contrib/test_ethosn/test_partition_params.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Ethos(TM)-N partition parameter tests""" + +import pytest +import tvm +from tvm import relay +import numpy as np + +from tvm.relay.op.contrib.ethosn import partition_for_ethosn77 +from tvm.relay.op.contrib.ethosn import partition_for_ethosn78 +from tvm.testing import requires_ethosn + + +@requires_ethosn +def test_ethosn78_partition_no_error(): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8") + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + opts = {"variant": "Ethos-N78"} + partition_for_ethosn78(mod, **opts) + + +@requires_ethosn +def test_ethosn78_partition_undefined_variant(): + with pytest.raises( + ValueError, match=r".*When targeting Ethos\(TM\)-N78, -variant=Ethos-N78 should be set.*" + ): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d( + a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8" + ) + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + partition_for_ethosn78(mod) + + +@requires_ethosn +def test_ethosn78_partition_invalid_variant(): + with pytest.raises( + ValueError, match=r".*When targeting Ethos\(TM\)-N78, -variant=Ethos-N78 should be set.*" + ): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d( + a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8" + ) + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + opts = {"variant": "Ethos-N"} + partition_for_ethosn78(mod, **opts) + + +@requires_ethosn +def test_ethosn78_partition_error(): + with pytest.raises( + ValueError, match=r".*When targeting Ethos\(TM\)-N78, -variant=Ethos-N78 should be set.*" + ): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d( + a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8" + ) + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + opts = {"variant": "Ethos-N77"} + partition_for_ethosn78(mod, **opts) + + +@requires_ethosn +def test_ethosn77_partition_no_error(): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d(a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8") + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + partition_for_ethosn77(mod) + + +@requires_ethosn +def test_ethosn77_partition_error(): + with pytest.raises( + ValueError, + match=r".*Setting tops, ple_ratio or sram_size has no effect when targeting Ethos\(TM\)-N77.*", + ): + a = relay.var("a", shape=[2, 7, 8, 8], dtype="uint8") + w = relay.const(np.random.uniform(-10, 10, (8, 7, 3, 3)).astype("uint8")) + res = relay.nn.conv2d( + a, w, kernel_size=(3, 3), padding=(1, 1), channels=8, out_dtype="uint8" + ) + b = relay.var("b", shape=[8], dtype="uint8") + res = relay.nn.bias_add(res, b, axis=1) + + mod = tvm.IRModule.from_expr(res) + opts = {"tops": 4} + partition_for_ethosn77(mod, **opts) diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 01a7ceb9ed56..17d3fad9cb30 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -17,7 +17,6 @@ """ This module provides infrastructure to verify the correctness of the command stream produced. - Currently it will invoke vela to generate a vela-optimized tflite in which the command stream is contained as a custom operator. This class include methods to parse the custom operator to extract @@ -383,14 +382,15 @@ def make_ethosu_conv2d( ifm_layout="NHWC", ofm_layout="NHWC", weight_dtype="int8", + scale_bias_dtype="uint8", ): # conv params weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels) padding = get_pad_tuple(padding, kernel_shape) - scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") - scale_bias = relay.const(scale_bias_data, dtype="uint8") - weight_data = generate_weights_data(weight_shape, "int8") + scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype) + scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype) + weight_data = generate_weights_data(weight_shape, weight_dtype) weight = relay.const(weight_data, dtype=weight_dtype) conv = ethosu_ops.ethosu_conv2d( ifm, @@ -428,13 +428,14 @@ def make_ethosu_depthwise_conv2d( ifm_layout="NHWC", ofm_layout="NHWC", weight_dtype="int8", + scale_bias_dtype="uint8", ): # params weight_shape = (channels, kernel_shape[0], kernel_shape[1], 1) padding = get_pad_tuple(padding, kernel_shape) - scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") - scale_bias = relay.const(scale_bias_data, dtype="uint8") + scale_bias_data = generate_weights_data((weight_shape[0], 10), scale_bias_dtype) + scale_bias = relay.const(scale_bias_data, dtype=scale_bias_dtype) weight_data = generate_weights_data(weight_shape, weight_dtype) weight = relay.const(weight_data, dtype=weight_dtype) depthwise = ethosu_ops.ethosu_depthwise_conv2d( @@ -460,3 +461,104 @@ def make_ethosu_depthwise_conv2d( ofm_layout=ofm_layout, ) return depthwise + + +def get_pooling_args(call, include_buffers=False): + args = call.args + pooling_args = [] + + for i, arg in enumerate(args): + if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + pooling_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + pooling_args.append(arg.index) + else: + pooling_args.append(arg) + + return pooling_args + + +def make_ethosu_pooling( + ifm, + pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", +): + pooling = ethosu_ops.ethosu_pooling( + ifm, + lut=relay.const([], dtype="int8"), + pooling_type=pooling_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + pool_shape=pool_shape, + ofm_channels=ofm_channels, + strides=strides, + padding=padding, + activation=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return pooling + + +def get_binary_elementwise_args(call, include_buffers=False): + args = call.args + binary_elementwise_args = [] + + for i, arg in enumerate(args): + if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + binary_elementwise_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + binary_elementwise_args.append(arg.index) + else: + binary_elementwise_args.append(arg) + + return binary_elementwise_args + + +def make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + ofm_dtype, + reversed_operands=False, + activation="NONE", + ifm_layout="NHWC", + ifm2_layout="NHWC", + ofm_layout="NHWC", +): + ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise( + ifm=ifm, + ifm2=ifm2, + lut=relay.const([], dtype="int8"), + operator_type=operator_type, + ifm_scale=1, + ifm_zero_point=0, + ifm2_scale=1, + ifm2_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ifm_channels=ifm_channels, + ifm2_channels=ifm2_channels, + reversed_operands=reversed_operands, + activation=activation, + ofm_dtype=ofm_dtype, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ifm_layout=ifm_layout, + ifm2_layout=ifm2_layout, + ofm_layout=ofm_layout, + ) + return ethosu_binary_elementwise diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 4949d6814ab2..a5686c81beb8 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -254,5 +254,342 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize( + "accel_type", + ACCEL_TYPES, +) +@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) +@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) +@pytest.mark.parametrize( + "pool_shape, strides, activation_function, padding", + [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], +) +def test_ethosu_pooling( + accel_type, + ifm_shape, + pooling_type, + strides, + pool_shape, + activation_function, + padding, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + if pooling_type == "MAX": + op = tf.nn.max_pool(x, pool_shape, strides, padding) + elif pooling_type == "AVG": + op = tf.nn.avg_pool(x, pool_shape, strides, padding) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape}, + dtype_dict={"x": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 3, 4], [1, 1, 1, 1]), + ([1, 1, 1, 1], [1, 2, 3, 4]), + ], +) +@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +def test_ethosu_binary_elementwise( + accel_type, + operator_type, + ifm_shape, + ifm2_shape, + activation_function, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, lhs, rhs): + if operator_type == "ADD": + op = tf.math.add(lhs, rhs) + elif operator_type == "SUB": + op = tf.math.subtract(lhs, rhs) + elif operator_type == "MUL": + op = tf.math.multiply(lhs, rhs) + elif operator_type == "MIN": + op = tf.math.minimum(lhs, rhs) + elif operator_type == "MAX": + op = tf.math.maximum(lhs, rhs) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape}, + dtype_dict={"ifm": dtype, "ifm2": dtype}, + ) + mod = partition_for_ethosu(mod, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + output_tolerance=1 if operator_type == "MAX" else 0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 3, 4], [1, 1, 3, 1]), + ([1, 1, 3, 1], [1, 2, 3, 4]), + ], +) +def test_ethosu_left_shift_binary_elemwise( + accel_type, + ifm_shape, + ifm2_shape, +): + dtype = "int32" + + def create_model(): + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + c1 = relay.left_shift(ifm, ifm2) + f = relay.Function([ifm, ifm2], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod + + relay_mod = create_model() + mod = partition_for_ethosu(relay_mod) + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(dtype) + input_data = { + "ifm": np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype), + "ifm2": np.random.randint(0, high=32, size=ifm2_shape, dtype=dtype), + } + output_data = generate_ref_data(relay_mod, input_data) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands, ofm_dtype", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False, "int8"), + ([1, 2, 3, 1], [1, 1, 3, 1], False, "int32"), + ([1, 1, 3, 1], [1, 2, 3, 1], True, "int32"), + ], +) +def test_ethosu_right_shift_binary_elemwise( + ifm_shape, ifm2_shape, reversed_operands, accel_type, ofm_dtype +): + dtype = "int32" + + def create_model(): + ifm_count = int(np.prod(ifm_shape)) + ifm2_count = int(np.prod(ifm2_shape)) + + # Create a "partitioned" Relay function + ifms = relay.var("ifms", shape=[ifm_count + ifm2_count], dtype=dtype) + split = relay.split(ifms, [ifm_count]) + ifm = relay.reshape(split[0], newshape=ifm_shape) + ifm2 = relay.reshape(split[1], newshape=ifm2_shape) + shr_op = infra.make_ethosu_binary_elementwise( + ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype, reversed_operands + ) + + glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0") + func = ( + relay.Function([ifms], shr_op) + .with_attr("Inline", 1) + .with_attr("Compiler", "ethosu") + .with_attr("global_symbol", "tvmgen_default_ethosu_main_0") + .with_attr("Primitive", 1) + ) + mod = tvm.IRModule() + mod[glb_ethosu] = func + mod = relay.transform.InferType()(mod) + + # Main + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + call = relay.Call( + glb_ethosu, + [ + relay.concatenate( + data=( + relay.reshape(ifm, newshape=ifm_count), + relay.reshape(ifm2, newshape=ifm2_count), + ), + axis=0, + ) + ], + ) + mod["main"] = relay.Function([ifm, ifm2], call) + mod = relay.transform.InferType()(mod) + return mod + + mod = create_model() + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(dtype) + in_min, in_max = 18, 19 + lhs = np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype) + rhs = np.random.randint(1, high=2, size=ifm2_shape, dtype=dtype) + input_data = { + "ifm": lhs, + "ifm2": rhs, + } + + if reversed_operands: + lhs = np.broadcast_to(lhs, ifm2_shape) + lhs, rhs = rhs, lhs + else: + rhs = np.broadcast_to(rhs, ifm_shape) + + def rounding_right_shift(lhs, rhs): + r = 1 << (rhs - 1) + return (lhs + r) >> rhs + + output_data = np.array( + [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)] + ).astype(ofm_dtype) + + compiled_model = infra.build_source(mod, input_data, [output_data], accel_type) + + imported_modules = compiled_model[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_model, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index b9a588d4aec0..2a84a23930e4 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -313,7 +313,7 @@ def verify_linear(ext_func, conv2d_params): for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) - mod = legalize.LegalizeEthosUConv2D()(mod) + mod = legalize.LegalizeConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -349,7 +349,7 @@ def create_graph_single_unsupported_ifm_layout( with pytest.raises( tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" ): - mod = legalize.LegalizeEthosUConv2D()(mod) + mod = legalize.LegalizeConv2D()(mod) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) @@ -458,7 +458,290 @@ def verify(ext_func): mod = partition_ethosu_by_table(mod, depthwise_pattern_table) mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( - legalize.EthosuDepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + +@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) +@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) +@pytest.mark.parametrize( + "pool_shape, strides, activation_function, padding", + [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], +) +def test_tflite_pool2d_legalize( + ifm_shape, pooling_type, strides, pool_shape, activation_function, padding +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + if pooling_type == "MAX": + op = tf.nn.max_pool(x, pool_shape, strides, padding) + elif pooling_type == "AVG": + op = tf.nn.avg_pool(x, pool_shape, strides, padding) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + ofm_shape = infra.compute_ofm_shape(ifm_shape, padding, pool_shape, strides) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == ifm_shape + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == ofm_shape + assert op.checked_type.dtype == dtype + assert op.attrs.pooling_type == pooling_type + assert list(op.attrs.strides) == strides + assert list(op.attrs.padding) == infra.compute_padding_shape( + ifm_shape, ofm_shape, padding, pool_shape, strides + ) + assert list(op.attrs.pool_shape) == pool_shape + assert op.attrs.ofm_channels == ifm_shape[3] + if activation_function == "RELU": + assert str(op.attrs.activation) == "CLIP" + + if pooling_type == "MAX": + rewriter = legalize.MaxPoolingRewriter() + pattern_table = [ + ( + ethosu.MaxPool2DParams.composite_name, + ethosu.qnn_maxpool2d_pattern(), + lambda pat: ethosu.MaxPool2DParams(pat).is_valid(), + ), + ] + elif pooling_type == "AVG": + rewriter = legalize.AvgPoolingRewriter() + pattern_table = [ + ( + ethosu.AvgPool2DParams.composite_name, + ethosu.qnn_avgpool2d_pattern(), + lambda pat: ethosu.AvgPool2DParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape}, + dtype_dict={"x": dtype}, + ) + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False), + ([1, 2, 3, 4], [1, 1, 3, 1], False), + ([1, 1, 3, 1], [1, 2, 3, 4], True), + ], +) +@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +def test_tflite_binary_elemwise_legalize( + operator_type, + ifm_shape, + ifm2_shape, + reversed_operands, + activation_function, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, y): + if operator_type == "ADD": + op = tf.math.add(x, y) + elif operator_type == "SUB": + op = tf.math.subtract(x, y) + elif operator_type == "MUL": + op = tf.math.multiply(x, y) + elif operator_type == "MIN": + op = tf.math.minimum(x, y) + elif operator_type == "MAX": + op = tf.math.maximum(x, y) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + out_shape = ifm2_shape if reversed_operands else ifm_shape + shapes = [ifm_shape, ifm2_shape] + ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == shapes[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + assert op.attrs.operator_type == operator_type + assert op.attrs.reversed_operands == reversed_operands + if activation_function == "RELU": + assert str(op.attrs.activation) == "CLIP" + + if operator_type == "ADD": + rewriter = legalize.AddRewriter() + pattern_table = [ + ( + ethosu.AddParams.composite_name, + ethosu.qnn_add_pattern(), + lambda pat: ethosu.AddParams(pat).is_valid(), + ), + ] + elif operator_type == "SUB": + rewriter = legalize.SubRewriter() + pattern_table = [ + ( + ethosu.SubParams.composite_name, + ethosu.qnn_subtract_pattern(), + lambda pat: ethosu.SubParams(pat).is_valid(), + ), + ] + elif operator_type == "MUL": + rewriter = legalize.MulRewriter() + pattern_table = [ + ( + ethosu.MulParams.composite_name, + ethosu.qnn_mul_pattern(), + lambda pat: ethosu.MulParams(pat).is_valid(), + ), + ] + elif operator_type == "MIN": + rewriter = legalize.MinRewriter() + pattern_table = [ + ( + ethosu.MinParams.composite_name, + ethosu.minimum_pattern(), + lambda pat: ethosu.MinParams(pat).is_valid(), + ), + ] + elif operator_type == "MAX": + rewriter = legalize.MaxRewriter() + pattern_table = [ + ( + ethosu.MaxParams.composite_name, + ethosu.maximum_pattern(), + lambda pat: ethosu.MaxParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape, "y": ifm2_shape}, + dtype_dict={"x": dtype, "y": dtype}, + ) + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False), + ([1, 2, 3, 4], [1, 1, 3, 1], False), + ([1, 1, 3, 1], [1, 2, 3, 4], True), + ], +) +def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands): + dtype = "int32" + operator_type = "SHL" + + def create_graph(): + input1 = relay.var("x1", shape=ifm_shape, dtype=dtype) + input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype) + c1 = relay.left_shift(input1, input2) + f = relay.Function([input1, input2], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod + + def verify(ext_func): + out_shape = ifm2_shape if reversed_operands else ifm_shape + shapes = [ifm_shape, ifm2_shape] + ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == shapes[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + assert op.attrs.operator_type == operator_type + assert op.attrs.reversed_operands == reversed_operands + assert str(op.attrs.activation) == "NONE" + + rewriter = legalize.ShlRewriter() + pattern_table = [ + ( + ethosu.ShlParams.composite_name, + ethosu.shl_pattern(), + lambda pat: ethosu.ShlParams(pat).is_valid(), + ), + ] + + mod = create_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] ) verify(mod["tvmgen_default_ethosu_main_0"]) diff --git a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py new file mode 100644 index 000000000000..6dcd9da395cc --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir import spec +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, ofm_layout", + [ + ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"), + ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"), + # Broadcast + ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"), + ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"), + ], +) +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize("activation", ["NONE", "CLIP"]) +def test_binary_elementwise_single( + ifm_shape, + ifm2_shape, + ifm_channels, + ifm2_channels, + ifm_layout, + ofm_layout, + operator_type, + activation, +): + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + False, + activation, + ifm_layout, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_binary_elementwise_args(stmt)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1 + ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1 + + ifm2_stride_c = 1 + ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1 + ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 else 1 + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[2] + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + + ifm2_stride_w = 16 + ifm2_stride_c = 16 * ifm2_shape[3] + ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3] + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[3] + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ifm_channels if ofm_width > 1 else 1 + ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1) + + serial_binary_elementwise = spec.SerialBinaryElementwise( + ifm=spec.SerialFeatureMap( + data_type=dtype, + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ifm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ifm2=spec.SerialFeatureMap( + data_type=dtype, + height=ifm2_shape[1], + width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + channels=ifm2_channels, + tile_height_0=ifm2_shape[1], + tile_height_1=0, + tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm2_stride_h, + stride_w=ifm2_stride_w, + stride_c=ifm2_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type=dtype, + height=ofm_height, + width=ofm_width, + channels=ifm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + operator_type=operator_type, + reversed_operands=False, + activation=spec.SerialActivation( + op=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ), + ) + + assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise) + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, ofm_layout", + [ + ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"), + ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"), + # Broadcast + ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"), + ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"), + ], +) +@pytest.mark.parametrize("operator_type", ["SHR", "SHL"]) +def test_shift_binary_elementwise_single( + ifm_shape, + ifm2_shape, + ifm_channels, + ifm2_channels, + ifm_layout, + ofm_layout, + operator_type, +): + dtype = "int32" + activation = "NONE" # Only NONE is available if the activation type is int32 + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + False, + "NONE", + ifm_layout, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_binary_elementwise_args(stmt)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1 + ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1 + + ifm2_stride_c = 1 + ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1 + ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 else 1 + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[2] + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + + ifm2_stride_w = 16 + ifm2_stride_c = 16 * ifm2_shape[3] + ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3] + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[3] + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ifm_channels if ofm_width > 1 else 1 + ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1) + + serial_binary_elementwise = spec.SerialBinaryElementwise( + ifm=spec.SerialFeatureMap( + data_type=dtype, + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ifm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ifm2=spec.SerialFeatureMap( + data_type=dtype, + height=ifm2_shape[1], + width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + channels=ifm2_channels, + tile_height_0=ifm2_shape[1], + tile_height_1=0, + tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm2_stride_h, + stride_w=ifm2_stride_w, + stride_c=ifm2_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type=dtype, + height=ofm_height, + width=ofm_width, + channels=ifm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + operator_type=operator_type, + reversed_operands=False, + activation=spec.SerialActivation( + op=activation, + clip_min=0, + clip_max=0, + ), + ) + + assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 76b7ef2a70ee..9590db57dd32 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir -from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, Convolution2DCompute from .infra import make_ethosu_conv2d @@ -73,5 +73,67 @@ def _get_func(): tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) +# fmt: off +@tvm.script.ir_module +class WeightStream: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 16], dtype="int8") + buffer = T.match_buffer(placeholder_1, [416], dtype="uint8") + buffer_1 = T.match_buffer(placeholder_2, [112], dtype="uint8") + buffer_2 = T.match_buffer(placeholder_3, [272], dtype="uint8") + buffer_3 = T.match_buffer(placeholder_4, [64], dtype="uint8") + # body + placeholder_global = T.allocate([416], "uint8", "global") + placeholder_d_global = T.allocate([112], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_weight_stream(): + def _cascader(cached_func, const_dict, sch): + weight = cached_func.inputs[1] + scale_bias = cached_func.inputs[2] + out = cached_func.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 10) + cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + sch[cache_weight].compute_at(sch[out], co) + sch[cache_scale_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=_cascader) + + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + reference_mod = WeightStream + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py new file mode 100644 index 000000000000..099b9d60c428 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir import spec +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_pooling, get_pooling_args + + +@pytest.mark.parametrize( + "ifm_shape, ofm_channels, ifm_layout, ofm_layout", + [ + ((1, 5, 9, 3), 3, "NHWC", "NHWC"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC"), + ((1, 8, 9, 40), 40, "NHWC", "NHCWB16"), + ], +) +@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"]) +@pytest.mark.parametrize("activation", ["NONE", "CLIP"]) +def test_pooling_single( + ifm_shape, + ofm_channels, + ifm_layout, + ofm_layout, + pooling_type, + activation, +): + pool_shape = (3, 2) + strides = (1, 2) + padding = (1, 1, 1, 0) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + pooling = make_ethosu_pooling( + ifm, + pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + activation, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(pooling), pooling) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_pooling_args(stmt)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] + ifm_stride_h = ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[2] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[3] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1 + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ofm_channels if ofm_width > 1 else 1 + ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1) + + serial_pooling = spec.SerialPooling( + ifm=spec.SerialFeatureMap( + data_type="int8", + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ofm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type="int8", + height=ofm_height, + width=ofm_width, + channels=ofm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + pooling_type=pooling_type, + pool_shape=spec.SerialKernel( + width=pool_shape[1], + height=pool_shape[0], + stride_w=strides[1], + stride_h=strides[0], + dilation_w=1, + dilation_h=1, + ), + padding=spec.SerialPadding( + top=padding[0], left=padding[1], bottom=padding[2], right=padding[3] + ), + activation=spec.SerialActivation( + op=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ), + upscale="NONE", + ) + + assert data[0] == ["ethosu_pooling"] + list(serial_pooling) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 8240b392a1cf..ab1bad226ae6 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -634,7 +634,7 @@ def populate_ethosu_copy_calls(stmt): for test_case in test_cases: ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"]) for idx, ethosu_copy_call in enumerate(ethosu_copy_calls): - npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_extern_call(ethosu_copy_call) + npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_call_extern(ethosu_copy_call) assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"] assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"] assert npu_dma_op.src.length == test_case["ref"][idx]["length"] @@ -675,7 +675,7 @@ def test_assign_addresses(): }, ] - def extract_extern_calls(mod): + def extract_call_extern_list(mod): """This function will obtain all ethosu_conv2d calls from a NPU TIR module Parameters @@ -825,10 +825,10 @@ def check_buffer(address, region, length, buffer_var): buffer_info = tir_to_cs_translator.extract_buffer_info( test_case["tir_module"], test_case["param_dict"] ) - extern_calls = extract_extern_calls(test_case["tir_module"]) + extern_calls = extract_call_extern_list(test_case["tir_module"]) _npu_ops = list() for extern_call in extern_calls: - _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_extern_call(extern_call)) + _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call)) npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) _npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses( buffer_info, _npu_ops @@ -842,5 +842,510 @@ def check_buffer(address, region, length, buffer_var): assert np.prod(constant_tensor_read_mask) == 1 +# fmt: off +"""A ethosu_pooling tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuPooling: + @T.prim_func + def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 5, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "NONE", dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +def test_translate_ethosu_pooling(): + def extract_ethosu_pooling_extern_call(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_pooling_calls = list() + + def populate_ethosu_pooling_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_pooling" + ): + ethosu_pooling_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_pooling_calls) + return ethosu_pooling_calls[0] + + pooling_call = extract_ethosu_pooling_extern_call(SingleEthosuPooling) + npu_op = tir_to_cs_translator.translate_ethosu_pooling(pooling_call) + + assert npu_op.ifm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ifm.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D(27, 3, 1) + # Compare OFM + assert npu_op.ofm.data_type == vapi.NpuDataType.INT8 + assert npu_op.ofm.shape == vapi.NpuShape3D(5, 5, 3) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D(15, 3, 1) + # Compare pooling_type + assert npu_op.sub_op_type == vapi.NpuPoolingOp.AVERAGE + # Compare kernel and padding + assert ( + npu_op.kernel.__dict__ + == vapi.NpuKernel(w=2, h=3, stride_x=2, stride_y=1, dilation_x=1, dilation_y=1).__dict__ + ) + assert npu_op.padding == vapi.NpuPadding(top=1, left=1, bottom=1, right=0) + # Compare activation + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 10 + assert npu_op.activation.max == 100 + # Compare ifm upscaling + assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE + + +# fmt: off +"""A ethosu_binary_elementwise ADD tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseAdd: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer( + placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ) + ethosu_write_2 = T.match_buffer( + ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ) + # body + T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, dtype="int8")) + + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise SUB tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseSub: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise MUL tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMul: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MIN tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMin: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise Max tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMax: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHR tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShr: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHL tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShl: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"]) +def test_translate_ethosu_binary_elementwise(operator_type): + if operator_type == "SHR" or operator_type == "SHL": + data_type = vapi.NpuDataType.INT32 + data_type_bytes = 4 + else: + data_type = vapi.NpuDataType.INT8 + data_type_bytes = 1 + + def extract_ethosu_binary_elementwise_call_extern(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_binary_elementwise_calls = list() + + def populate_ethosu_binary_elementwise_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_binary_elementwise" + ): + ethosu_binary_elementwise_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls) + return ethosu_binary_elementwise_calls[0] + + if operator_type == "ADD": + binary_elementwise = SingleEthosuBinaryElementwiseAdd + elif operator_type == "SUB": + binary_elementwise = SingleEthosuBinaryElementwiseSub + elif operator_type == "MUL": + binary_elementwise = SingleEthosuBinaryElementwiseMul + elif operator_type == "MIN": + binary_elementwise = SingleEthosuBinaryElementwiseMin + elif operator_type == "MAX": + binary_elementwise = SingleEthosuBinaryElementwiseMax + elif operator_type == "SHR": + binary_elementwise = SingleEthosuBinaryElementwiseShr + elif operator_type == "SHL": + binary_elementwise = SingleEthosuBinaryElementwiseShl + binary_elementwise_call = extract_ethosu_binary_elementwise_call_extern(binary_elementwise) + npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call) + + # Compare IFM + assert npu_op.ifm.data_type == data_type + assert npu_op.ifm.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare IFM2 + assert npu_op.ifm2.data_type == data_type + assert npu_op.ifm2.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm2.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare OFM + assert npu_op.ofm.data_type == data_type + assert npu_op.ofm.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare op type + if operator_type == "ADD": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL + # Compare reversed_operands + assert npu_op.reversed_operands == False + # Compare activation + if operator_type == "SHR": + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 10 + assert npu_op.activation.max == 100 + + +# fmt: off +"""A ethosu_binary_elementwise ADD with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseAddBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise SUB with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseSubBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise MUL with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMulBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MIN with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMinBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MAX with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMaxBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHR with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShrBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHL with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShlBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"]) +def test_translate_ethosu_binary_elementwise_broadcasting(operator_type): + if operator_type == "SHR" or operator_type == "SHL": + data_type = vapi.NpuDataType.INT32 + data_type_bytes = 4 + else: + data_type = vapi.NpuDataType.INT8 + data_type_bytes = 1 + + def extract_ethosu_binary_elementwise_broadcasting_call_extern(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_binary_elementwise_calls = list() + + def populate_ethosu_binary_elementwise_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_binary_elementwise" + ): + ethosu_binary_elementwise_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls) + return ethosu_binary_elementwise_calls[0] + + if operator_type == "ADD": + binary_elementwise = SingleEthosuBinaryElementwiseAddBroadcasting + elif operator_type == "SUB": + binary_elementwise = SingleEthosuBinaryElementwiseSubBroadcasting + elif operator_type == "MUL": + binary_elementwise = SingleEthosuBinaryElementwiseMulBroadcasting + elif operator_type == "MIN": + binary_elementwise = SingleEthosuBinaryElementwiseMinBroadcasting + elif operator_type == "MAX": + binary_elementwise = SingleEthosuBinaryElementwiseMaxBroadcasting + elif operator_type == "SHR": + binary_elementwise = SingleEthosuBinaryElementwiseShrBroadcasting + elif operator_type == "SHL": + binary_elementwise = SingleEthosuBinaryElementwiseShlBroadcasting + binary_elementwise_call = extract_ethosu_binary_elementwise_broadcasting_call_extern( + binary_elementwise + ) + npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call) + + # Compare IFM + assert npu_op.ifm.data_type == data_type + assert npu_op.ifm.shape == vapi.NpuShape3D(2, 3, 4) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D( + 12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes + ) + # Compare IFM2 + assert npu_op.ifm2.data_type == data_type + assert npu_op.ifm2.shape == vapi.NpuShape3D(1, 3, 1) + assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm2.strides == vapi.NpuShape3D( + 1 * data_type_bytes, 1 * data_type_bytes, 1 * data_type_bytes + ) + # Compare OFM + assert npu_op.ofm.data_type == data_type + assert npu_op.ofm.shape == vapi.NpuShape3D(2, 3, 4) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D( + 12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes + ) + # Compare op type + if operator_type == "ADD": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL + # Compare reversed_operands + assert npu_op.reversed_operands == True + # Compare activation + + if operator_type == "SHR": + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 10 + assert npu_op.activation.max == 100 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py index 47fddad773b2..e068439fcee5 100644 --- a/tests/python/contrib/test_ethosu/test_type_inference.py +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -18,10 +18,13 @@ pytest.importorskip("ethosu.vela") +from tvm import relay, TVMError from tvm import relay from tvm.relay.testing import run_opt_pass from .infra import make_ethosu_conv2d from .infra import make_ethosu_depthwise_conv2d +from .infra import make_ethosu_pooling +from .infra import make_ethosu_binary_elementwise @pytest.mark.parametrize( @@ -54,9 +57,37 @@ def test_ethosu_conv2d_type_inference( ifm_layout=ifm_layout, ofm_layout=ofm_layout, ) - f = relay.Function([ifm], conv2d) - f = run_opt_pass(f, relay.transform.InferType()) - assert tuple(f.body.checked_type.shape) == ofm_shape + func = relay.Function([ifm], conv2d) + func = run_opt_pass(func, relay.transform.InferType()) + assert tuple(func.body.checked_type.shape) == ofm_shape + + +@pytest.mark.parametrize( + "ifm_dtype,weight_dtype,scale_bias_dtype", + [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")], +) +def test_ethosu_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype): + ifm_channels = 55 + ofm_channels = 122 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype) + conv2d = make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + weight_dtype=weight_dtype, + scale_bias_dtype=scale_bias_dtype, + ) + func = relay.Function([ifm], conv2d) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) @pytest.mark.parametrize( @@ -87,9 +118,228 @@ def test_ethosu_depthwise_conv2d_type_inference( ifm_layout=ifm_layout, ofm_layout=ofm_layout, ) - f = relay.Function([ifm], depthwise_conv2d) - f = run_opt_pass(f, relay.transform.InferType()) - assert tuple(f.body.checked_type.shape) == ofm_shape + func = relay.Function([ifm], depthwise_conv2d) + func = run_opt_pass(func, relay.transform.InferType()) + assert tuple(func.body.checked_type.shape) == ofm_shape + + +@pytest.mark.parametrize( + "ifm_dtype,weight_dtype,scale_bias_dtype", + [("float32", "int8", "uint8"), ("int8", "float32", "uint8"), ("int8", "int8", "float32")], +) +def test_ethosu_depthwise_conv2d_invalid_dtypes(ifm_dtype, weight_dtype, scale_bias_dtype): + channels = 55 + kernel_shape = (3, 2) + padding = (0, 1, 2, 3) + strides = (1, 2) + dilation = (2, 1) + dilation = (2, 1) + ifm = relay.var("ifm", shape=(1, 56, 72, 55), dtype=ifm_dtype) + depthwise_conv2d = make_ethosu_depthwise_conv2d( + ifm, + channels, + kernel_shape, + padding, + strides, + dilation, + weight_dtype=weight_dtype, + scale_bias_dtype=scale_bias_dtype, + ) + func = relay.Function([ifm], depthwise_conv2d) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 56, 38, 55), "NHWC"), ((1, 56, 4, 38, 16), "NHCWB16")] +) +def test_ethosu_pooling_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + pooling_type = "AVG" + pool_shape = (3, 2) + ofm_channels = 55 + strides = (1, 2) + padding = (0, 1, 2, 3) + pooling = make_ethosu_pooling( + ifm, + pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + func = relay.Function([ifm], pooling) + func = run_opt_pass(func, relay.transform.InferType()) + assert tuple(func.body.checked_type.shape) == ofm_shape + assert func.body.checked_type.dtype == dtype + + +def test_ethosu_pooling_invalid_pooling_type(): + invalid_pooling_type = "A" + dtype = "int8" + ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype) + pool_shape = (3, 2) + ofm_channels = 55 + strides = (1, 2) + padding = (0, 1, 2, 3) + pooling = make_ethosu_pooling( + ifm, + invalid_pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + ) + func = relay.Function([ifm], pooling) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +def test_ethosu_pooling_invalid_dtype(): + invalid_dtype = "int32" + ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=invalid_dtype) + pooling_type = "MAX" + pool_shape = (3, 2) + ofm_channels = 55 + strides = (1, 2) + padding = (0, 1, 2, 3) + pooling = make_ethosu_pooling( + ifm, + pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + ) + func = relay.Function([ifm], pooling) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +def test_ethosu_binary_elementwise_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype) + operator_type = "ADD" + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + ifm_layout=ifm_layout, + ifm2_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + assert tuple(func.body.checked_type.shape) == ofm_shape + assert func.body.checked_type.dtype == dtype + + +def test_ethosu_binary_elementwise_invalid_operator_type(): + invalid_operator_type = "A" + ifm_shape = [1, 4, 5, 33] + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + invalid_operator_type, + dtype, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +def test_ethosu_binary_elementwise_invalid_data_types(): + dtype = "int8" + dtype2 = "int32" + operator_type = "ADD" + ifm_shape = [1, 4, 5, 33] + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype2) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +@pytest.mark.parametrize("operator_type", ["MIN", "MAX"]) +def test_ethosu_binary_elementwise_min_max_invalid_data_type(operator_type): + invalid_dtype = "int32" + ifm_shape = [1, 4, 5, 33] + ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + invalid_dtype, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +@pytest.mark.parametrize("invalid_dtype", ["int8", "uint8"]) +@pytest.mark.parametrize("operator_type", ["RHS", "SHR"]) +def test_ethosu_binary_elementwise_shift_invalid_data_type(invalid_dtype, operator_type): + ifm_shape = [1, 4, 5, 33] + ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + invalid_dtype, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) if __name__ == "__main__": diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md index a47c3438bf57..674e1af6029f 100644 --- a/tests/python/contrib/test_hexagon/README.md +++ b/tests/python/contrib/test_hexagon/README.md @@ -17,8 +17,7 @@ Documents manual TE schedule to illustrate Hexagon operator slicing. -# High Level Notes - +High Level Notes: * Using float32 (for now) so that tests will pass on CPU * Using global storage scope (for now) which means "cache" reads and writes from global, to global * TIR is pending changes from the work-in-progress layout RFC @@ -33,485 +32,6 @@ Documents manual TE schedule to illustrate Hexagon operator slicing. * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in * Using `rh` and `rw` (reduction height / width) to denote filter height and width -# Calling Convention - -TODO: Map this packed string to parameters -conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm - -# Baseline conv2d - -This is a baseline 1x1 conv2d schedule for Hexagon. - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----------- | -| Batch | 1 | -| Filter | 1x1 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Output Ch | 128 | -| Stride | 1 | -| Padding | 0 | -| Layout | NHWC8h8w32c | - -## Assumptions - -* Pattern matching for microkernels is not senstive to cache reads and writes between the outer height (ho) and outer width (wo) loops. - -## To Do - -* n/a - -## Annotated TIR - -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; - allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; - - for (ko.outer: int32, 0, 4) { - for (ho.outer: int32, 0, 8) { - - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - - // filter cache read - for (co: int32, 0, 2) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } - - // compute - for (wo.c: int32, 0, 8) { - - // init output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // convolution - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } // end wo.c - - // cache write - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.outer - } // end ko.outer -} -``` - -# Split on Channel Out and Height - "Full Output Slice" - -Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. - -The key changes in TIR versus the above are... - -1) Increased cache allocations: - -``` - // input cache grows by factor of h_split = 2 - allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // filter cache grows by factor of k_split = 2 - allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; - - // output cache grows by factor of h_split * k_split = 4 - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; -``` - -2) Outer loop splits using k_split and h_split factors - -``` - // ko.outer = outer loop split on ko using k_split factor - for (ko.outer: int32, 0, 2) { - // ho.outer = outer loop split on ho using h_split factor - for (ho.outer: int32, 0, 4) { -``` - -3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. -``` - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { -``` - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----------- | -| Batch | 1 | -| Filter | 1x1 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Output Ch | 128 | -| Stride | 1 | -| Padding | 0 | -| Layout | NHWC8h8w32c | -| k_split | 2 | -| h_split | 2 | - -## Assumptions - -* n/a - With the loop splits on `ko` and `ho` the compute schedule is now over `ko.inner` `ho.inner` `wo` etc. This should fit the pattern matching for microkernels. - -## To Do - -* n/a - -## Annotated TIR - -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - - // input cache grows by factor of h_split = 2 - allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // filter cache grows by factor of k_split = 2 - allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; - - // output cache grows by factor of h_split * k_split = 4 - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // ko.outer = outer loop split on ko using k_split factor - for (ko.outer: int32, 0, 2) { - // ho.outer = outer loop split on ho using h_split factor - for (ho.outer: int32, 0, 4) { - - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - } // end ho.inner - - // filter cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 2) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } - } // end ko.inner - - // compute - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - - // init output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // convolution - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } // end wo.c - } // end ho.c.inner - } // end ko.c.inner - - // cache write - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.inner - } // end ko.inner - } // end ho.outer - } // end ko.outer -} -``` - -# 3x3 conv2d (no padding) - -Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. - -The key changes in TIR versus the above are... - -1) Increased input cache size to hold the vertically adjacent slice - -``` - // input cache grows to hold vertically adjacent slice - allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; -``` - -2) Loop over `ho.inner` upper bound increased from `h_split` = 2 to `h_split + 1` = 3 - -``` - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 3) { - if (((ho.outer*2) + ho.inner) < 8) { -``` - -The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. - - -3) Increased filter cache size to hold 3x3 filter - -``` - // filter cache grows to hold larger 3x3 filter - allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; -``` - -4) Loops over `rh` and `rw` the kernel spatial dimensions: -``` - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { -``` - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-2-1-64-64-128-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----------- | -| Batch | 1 | -| Filter | 3x3 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Output Ch | 128 | -| Stride | 1 | -| Padding | 0 | -| Layout | NHWC8h8w32c | -| h_split | 2 | - -## Assumptions - -* n/a - -## To Do - -There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: - -| ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | -| -------- | -------- | ------------------------------------- | -| 0 | 0 | 0 | -| 0 | 1 | 32k | -| 0 | 2 | 64k (vertical adjacent slice loop 0) | -| 1 | 0 | 64k | -| 1 | 1 | 96k | -| 1 | 2 | 128k (vertical adjacent slice loop 1) | -| 2 | 0 | 128k | -| 2 | 1 | 160k | -| 2 | 2 | 192k (vertical adjacent slice loop 2) | -| 3 | 0 | 192k | -| 3 | 1 | 224k | -| 3 | 2 | (No vertical adjacent slice loop 3) | - -Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` = N) is reused in loop N + 1. - -## Annotated TIR +[Conv2d](test_conv2d_blocked.md) -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - // input cache grows to hold vertically adjacent slice - allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; - // filter cache grows to hold larger 3x3 filter - allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (ho.inner: int32, 0, 3) { - if (((ho.outer*2) + ho.inner) < 8) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - } - } - // filter cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 2) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } // end rw - } // end rh - } - } - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * - (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } // end rw - } // end rh - } - } - } - } // end wo.c - } // end ho.c.inner - } // end ko.c.inner - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.inner - } // end ko.inner - } // end ho.outer - } // end ko.outer -}``` \ No newline at end of file +[Conv2d -> Conv2d](test_conv2d_conv2d.md) \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 193a8630c3d2..4befcc62556f 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -39,16 +39,25 @@ def get_packed_activation_layout(shape_nhwc, block_shape, packed_C=True): return shape +def get_block_shape(): + return 8, 8, 32 + + +def get_filter_block_shape(): + return 8, 32, 4 + + def get_packed_filter_layout(out_channel, in_channel, kernel_h, kernel_w): - out_factor, in_first_factor, in_second_factor = 32, 32, 4 + filter_Cio, filter_Ki, filter_Cii = get_filter_block_shape() + filter_Ci = filter_Cio * filter_Cii return ( - int(ceildiv(out_channel, out_factor)), - int(ceildiv(in_channel, in_first_factor)), + int(ceildiv(out_channel, filter_Ki)), + int(ceildiv(in_channel, filter_Ci)), kernel_h, kernel_w, - in_first_factor // in_second_factor, - out_factor, - in_second_factor, + filter_Cio, + filter_Ki, + filter_Cii, ) @@ -71,10 +80,6 @@ def build_and_run(inputs, func, target, target_host, *args, **kwargs): return tensors[-1].asnumpy() -def get_block_shape(): - return 8, 8, 32 - - def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, out_channels): assert len(shape_nhwc) == 4 kernel = [] @@ -86,3 +91,41 @@ def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, o (shape_nhwc[2] - kernel[1] + padding[2] + padding[3]) // strides[1] + 1, out_channels, ) + + +def verify_conv2d(output, ref_output, dtype): + # nhwc8h8w32c + if len(output.shape) == 7: + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # nhwhwc + else: + # nhwhwc -> nhwc + output = output.transpose(0, 1, 3, 2, 4, 5).reshape( + output.shape[0], + output.shape[1] * output.shape[3], + output.shape[2] * output.shape[4], + output.shape[5], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.md b/tests/python/contrib/test_hexagon/test_conv2d_blocked.md new file mode 100644 index 000000000000..aebaaffde939 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.md @@ -0,0 +1,496 @@ + + + + + + + + + + + + + + + + + +Hexagon conv2d schedules + +# Baseline conv2d + +This is a baseline 1x1 conv2d schedule for Hexagon. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_nhwc8h8w32c-1-1-0-float32-1-1-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | + +## Assumptions + +* Pattern matching for microkernels is not senstive to cache reads and writes between the outer height (ho) and outer width (wo) loops. + +## To Do + +* n/a + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; + + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + + // filter cache read + for (co: int32, 0, 2) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } + + // compute + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } // end wo.c + + // cache write + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.outer + } // end ko.outer +} +``` + +# Split on Channel Out and Height - "Full Output Slice" + +Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. + +The key changes in TIR versus the above are... + +1) Increased cache allocations: + +``` + // input cache grows by factor of h_split = 2 + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; +``` + +2) Outer loop splits using k_split and h_split factors + +``` + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { +``` + +3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. +``` + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { +``` + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_nhwc8h8w32c-1-1-0-float32-2-2-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 1x1 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| k_split | 2 | +| h_split | 2 | + +## Assumptions + +* n/a - With the loop splits on `ko` and `ho` the compute schedule is now over `ko.inner` `ho.inner` `wo` etc. This should fit the pattern matching for microkernels. + +## To Do + +* n/a + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + + // input cache grows by factor of h_split = 2 + allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } // end ho.inner + + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } + } // end ko.inner + + // compute + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + + // cache write + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +} +``` + +# 3x3 conv2d (no padding) + +Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. + +The key changes in TIR versus the above are... + +1) Increased input cache size to hold the vertically adjacent slice + +``` + // input cache grows to hold vertically adjacent slice + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; +``` + +2) Loop over `ho.inner` upper bound increased from `h_split` = 2 to `h_split + 1` = 3 + +``` + for (ho.outer: int32, 0, 4) { + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { +``` + +The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + + +3) Increased filter cache size to hold 3x3 filter + +``` + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; +``` + +4) Loops over `rh` and `rw` the kernel spatial dimensions: +``` + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { +``` + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_nhwc8h8w32c-3-1-0-float32-2-2-1-64-64-128-llvm]" + +## Parameters + +| Parameter | Value | +| --------- | ----------- | +| Batch | 1 | +| Filter | 3x3 | +| Spatial | 64x64 | +| Input Ch | 64 | +| Output Ch | 128 | +| Stride | 1 | +| Padding | 0 | +| Layout | NHWC8h8w32c | +| h_split | 2 | + +## Assumptions + +* n/a + +## To Do + +There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: + +| ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | +| -------- | -------- | ------------------------------------- | +| 0 | 0 | 0 | +| 0 | 1 | 32k | +| 0 | 2 | 64k (vertical adjacent slice loop 0) | +| 1 | 0 | 64k | +| 1 | 1 | 96k | +| 1 | 2 | 128k (vertical adjacent slice loop 1) | +| 2 | 0 | 128k | +| 2 | 1 | 160k | +| 2 | 2 | 192k (vertical adjacent slice loop 2) | +| 3 | 0 | 192k | +| 3 | 1 | 224k | +| 3 | 2 | (No vertical adjacent slice loop 3) | + +Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` = N) is reused in loop N + 1. + +## Annotated TIR + +``` +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + // input cache grows to hold vertically adjacent slice + allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } + } + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } // end rw + } // end rh + } + } + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * + (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } // end rw + } // end rh + } + } + } + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } + } + } + } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +}``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index 07696b51a327..f3da3e1f8c09 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -26,138 +26,19 @@ build_and_run, get_block_shape, get_conv2d_nhwc_shape, + get_filter_block_shape, get_packed_filter_layout, get_packed_activation_layout, + verify_conv2d, ) import numpy as np import pytest -def conv2d_logical( +def conv2d_nhwc8h8w32c( shape_nhwc, - shape_oihw, - kernel_size, - stride, - padding, - dtype, - storage_scope="global", -): - """ - Conv2d TE wherein both input activation and filter tensors - are defined with their logical NHWC/OIHW shapes, respectively. - The packed physical layout for the activation and filter are: - Activation: nhwc8h8w32c - Filter: oihw8i32o4i - """ - assert kernel_size == tuple(shape_oihw[2:]) - - block_shape = get_block_shape() - block_H, block_W, block_C = block_shape - shape = get_packed_activation_layout(shape_nhwc, block_shape) - logical_output_shape = get_conv2d_nhwc_shape( - shape_nhwc, kernel_size, stride, padding, [1, 1], shape_oihw[0] - ) - output_shape = get_packed_activation_layout(logical_output_shape, block_shape) - - N, H, W, C = shape_nhwc - X = te.placeholder(shape_nhwc, dtype=dtype) - # Combination of padding required by conv2d operator and padding to evenly divisible - # number of blocks. Note that this padding should be inlined in the schedule so - # as to avoid input copying. - pad_h = (block_H - ((H + padding[1]) % block_H)) % block_H - pad_w = (block_W - ((W + padding[3]) % block_W)) % block_W - X_pad = topi.nn.pad(X, [0, padding[0], padding[2], 0], [0, pad_h, pad_w, 0], pad_value=0) - # Calculate packed layout - X_packed = te.compute( - shape, - lambda n, ho, wo, co, hi, wi, ci: X_pad[ - n, ho * block_H + hi, wo * block_W + wi, co * block_C + ci - ], - ) - - # Filter shape using KCRS (OIHW) notation - K, C, R, S = shape_oihw - filter_Ki, filter_Ci, filter_Cii = 32, 32, 4 - shape_filter = get_packed_filter_layout(K, C, R, S) - filt = te.placeholder(shape_oihw, dtype=dtype) - # Channel padding to multiples of 32 - pad_c = (filter_Ci - (C % filter_Ci)) % filter_Ci - pad_k = (filter_Ki - (K % filter_Ki)) % filter_Ki - filt_pad = topi.nn.pad( - filt, [0, 0, 0, 0], [pad_k, pad_c, R, S], pad_value=0, name="padded_filter" - ) - filt_packed = te.compute( - shape_filter, - lambda ko, co, r, s, cio, ki, cii: filt_pad[ - ko * filter_Ki + ki, co * filter_Ci + cio * filter_Cii + cii, r, s - ], - name="packed_filter", - ) - - rh = te.reduce_axis((0, kernel_size[0]), name="rh") - rw = te.reduce_axis((0, kernel_size[1]), name="rw") - rc = te.reduce_axis((0, C), name="rc") - - def compute(n, ho, wo, ko, hi, wi, ki): - # Construct blockized strided conv2d height index - h = ho * block_H + hi - h_contig = h * stride[0] + rh - h_block_id = h_contig // block_H - h_block_offset = h_contig % block_H - - # Construct blockized strided conv2d width index - w = wo * block_W + wi - w_contig = w * stride[1] + rw - w_block_id = w_contig // block_W - w_block_offset = w_contig % block_W - - # Construct blockized conv2d channel index - c_block_id = rc // block_C - c_block_offset = rc % block_C - - # Construct flat filter input channel indices - rco = rc // filter_Ci - rcio = (rc % filter_Ci) // filter_Cii - rcii = rc % filter_Cii - - return te.sum( - X_packed[ - n, - h_block_id, - w_block_id, - c_block_id, - h_block_offset, - w_block_offset, - c_block_offset, - ] - * filt_packed[ko, rco, rh, rw, rcio, ki, rcii], - axis=[rh, rw, rc], - ) - - Y = te.compute(output_shape, compute) - s = te.create_schedule(Y.op) - - # Ensure the padding and array packing is performed inline - s[X_pad].compute_inline() - s[X_packed].compute_inline() - - s[filt_pad].compute_inline() - s[filt_packed].compute_inline() - - binds = {} - if storage_scope and storage_scope != "global": - with tvm.transform.PassContext(): - Xb = tvm.tir.decl_buffer(shape, name="Xb", dtype=dtype, scope=storage_scope) - Yb = tvm.tir.decl_buffer(output_shape, name="Yb", dtype=dtype, scope=storage_scope) - binds = {X: Xb, Y: Yb} - - return (s, [X, filt, Y], binds) - - -def conv2d_packed_filter( - shape_nhwc, - shape_oihw8i32o4i, + shape_filter, kernel_size, stride, padding, @@ -168,11 +49,19 @@ def conv2d_packed_filter( ): """ Conv2d TE wherein the input activation is defined by its - logical NHWC shape, but the filter is provided in the - packed layout oihw8i32o4i. The physical packed layout used - for the activation is: nhwc8h8w32c + logical NHWC shape. The filter is provided in either its + logical (OIHW) or physical packed (oihw8i32o4i) shape. The + physical packed layout for the input / output is nhwc8h8w32c. """ - assert kernel_size == tuple(shape_oihw8i32o4i[2:4]) + + # oihw8i32o41 + if len(shape_filter) == 7: + assert kernel_size == tuple(shape_filter[2:4]) + out_channels = shape_filter[0] * shape_filter[5] + # oihw + else: + assert kernel_size == tuple(shape_filter[2:]) + out_channels = shape_filter[0] block_shape = get_block_shape() block_H, block_W, block_C = block_shape @@ -183,7 +72,7 @@ def conv2d_packed_filter( stride, padding, [1, 1], - shape_oihw8i32o4i[0] * shape_oihw8i32o4i[5], + out_channels, ) output_shape = get_packed_activation_layout(logical_output_shape, block_shape) @@ -195,11 +84,10 @@ def conv2d_packed_filter( # as to avoid input copying. pad_h = (block_H - ((H + padding[1]) % block_H)) % block_H pad_w = (block_W - ((W + padding[3]) % block_W)) % block_W - X_pad = topi.nn.pad(X, [0, padding[0], padding[2], 0], [0, pad_h, pad_w, 0], pad_value=0) + # Calculate packed layout packed_shape = get_packed_activation_layout(X_pad.shape, block_shape) - X_packed = te.compute( packed_shape, lambda n, ho, wo, co, hi, wi, ci: X_pad[ @@ -207,13 +95,38 @@ def conv2d_packed_filter( ], ) - # Filter shape using KCRS (OIHW) notation - filter_Ki, filter_Ci, filter_Cii = 32, 32, 4 - assert shape_oihw8i32o4i[-1] == filter_Cii - assert shape_oihw8i32o4i[-2] == filter_Ki - assert shape_oihw8i32o4i[-3] == filter_Ci // filter_Cii + filter_Cio, filter_Ki, filter_Cii = get_filter_block_shape() + filter_Ci = filter_Cio * filter_Cii + + if len(shape_filter) == 7: + assert shape_filter[-1] == filter_Cii + assert shape_filter[-2] == filter_Ki + assert shape_filter[-3] == filter_Cio - filt_packed = te.placeholder(shape_oihw8i32o4i, dtype=dtype) + filt = te.placeholder(shape_filter, dtype=dtype) + filt_packed = filt + + else: + filt = te.placeholder(shape_filter, dtype=dtype) + + # get logical filter shape KCRS (OIHW) + K, C, R, S = shape_filter + + # Channel padding to multiples of 32 + pad_c = (filter_Ci - (C % filter_Ci)) % filter_Ci + pad_k = (filter_Ki - (K % filter_Ki)) % filter_Ki + filt_pad = topi.nn.pad( + filt, [0, 0, 0, 0], [pad_k, pad_c, R, S], pad_value=0, name="padded_filter" + ) + + shape_packed_filter = get_packed_filter_layout(K, C, R, S) + filt_packed = te.compute( + shape_packed_filter, + lambda ko, co, r, s, cio, ki, cii: filt_pad[ + ko * filter_Ki + ki, co * filter_Ci + cio * filter_Cii + cii, r, s + ], + name="packed_filter", + ) rh = te.reduce_axis((0, kernel_size[0]), name="rh") rw = te.reduce_axis((0, kernel_size[1]), name="rw") @@ -262,6 +175,11 @@ def compute(n, ho, wo, ko, hi, wi, ki): s[X_pad].compute_inline() s[X_packed].compute_inline() + # if we did filter padding, packing + if filt != filt_packed: + s[filt_pad].compute_inline() + s[filt_packed].compute_inline() + # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) Fl = s.cache_read(filt_packed, storage_scope, [Y]) @@ -314,12 +232,12 @@ def compute(n, ho, wo, ko, hi, wi, ki): Yb = tvm.tir.decl_buffer(output_shape, name="Yb", dtype=dtype, scope=storage_scope) binds = {X: Xb, Y: Yb} - return (s, [X, filt_packed, Y], binds) + return (s, [X, filt, Y], binds) -def conv2d_packed_filter_nhwhwc( +def conv2d_nhw8h8wc( shape_nhwc, - shape_oihw8i32o4i, + shape_filter, kernel_size, stride, padding, @@ -330,12 +248,19 @@ def conv2d_packed_filter_nhwhwc( ): """ Conv2d TE wherein the input activation is defined by its - logical NHWC shape, but the filter is provided in the - packed layout oihw8i32o4i. The physical packed layout used - for the activation is: nhw8h8wc - + logical NHWC shape. The filter is provided in either its + logical (OIHW) or physical packed (oihw8i32o4i) shape. The + physical packed layout for the input / output is nhw8h8wc. """ - assert kernel_size == tuple(shape_oihw8i32o4i[2:4]) + + # oihw8i32o41 + if len(shape_filter) == 7: + assert kernel_size == tuple(shape_filter[2:4]) + out_channels = shape_filter[0] * shape_filter[5] + # oihw + else: + assert kernel_size == tuple(shape_filter[2:]) + out_channels = shape_filter[0] block_shape = get_block_shape() block_H, block_W, block_C = block_shape @@ -346,8 +271,9 @@ def conv2d_packed_filter_nhwhwc( stride, padding, [1, 1], - shape_oihw8i32o4i[0] * shape_oihw8i32o4i[5], + out_channels, ) + output_shape = get_packed_activation_layout(logical_output_shape, block_shape, packed_C=False) N, H, W, C = shape_nhwc @@ -358,19 +284,45 @@ def conv2d_packed_filter_nhwhwc( pad_h = (block_H - ((H + padding[1]) % block_H)) % block_H pad_w = (block_W - ((W + padding[3]) % block_W)) % block_W X_pad = topi.nn.pad(X, [0, padding[0], padding[2], 0], [0, pad_h, pad_w, 0], pad_value=0) + # Calculate packed layout packed_shape = get_packed_activation_layout(X_pad.shape, block_shape, packed_C=False) X_packed = te.compute( packed_shape, lambda n, ho, wo, hi, wi, c: X_pad[n, ho * block_H + hi, wo * block_W + wi, c] ) - # Filter shape using KCRS (OIHW) notation - filter_Ki, filter_Ci, filter_Cii = 32, 32, 4 - assert shape_oihw8i32o4i[-1] == filter_Cii - assert shape_oihw8i32o4i[-2] == filter_Ki - assert shape_oihw8i32o4i[-3] == filter_Ci // filter_Cii + filter_Cio, filter_Ki, filter_Cii = get_filter_block_shape() + filter_Ci = filter_Cio * filter_Cii + + if len(shape_filter) == 7: + assert shape_filter[-1] == filter_Cii + assert shape_filter[-2] == filter_Ki + assert shape_filter[-3] == filter_Cio + + filt = te.placeholder(shape_filter, dtype=dtype) + filt_packed = filt + + else: + filt = te.placeholder(shape_filter, dtype=dtype) + + # get logical filter shape KCRS (OIHW) + K, C, R, S = shape_filter + + # Channel padding to multiples of 32 + pad_c = (filter_Ci - (C % filter_Ci)) % filter_Ci + pad_k = (filter_Ki - (K % filter_Ki)) % filter_Ki + filt_pad = topi.nn.pad( + filt, [0, 0, 0, 0], [pad_k, pad_c, R, S], pad_value=0, name="padded_filter" + ) - filt_packed = te.placeholder(shape_oihw8i32o4i, dtype=dtype) + shape_packed_filter = get_packed_filter_layout(K, C, R, S) + filt_packed = te.compute( + shape_packed_filter, + lambda ko, co, r, s, cio, ki, cii: filt_pad[ + ko * filter_Ki + ki, co * filter_Ci + cio * filter_Cii + cii, r, s + ], + name="packed_filter", + ) rh = te.reduce_axis((0, kernel_size[0]), name="rh") rw = te.reduce_axis((0, kernel_size[1]), name="rw") @@ -411,6 +363,11 @@ def compute(n, ho, wo, hi, wi, k): s[X_pad].compute_inline() s[X_packed].compute_inline() + # if we did filter padding, packing + if filt != filt_packed: + s[filt_pad].compute_inline() + s[filt_packed].compute_inline() + # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) Fl = s.cache_read(filt_packed, storage_scope, [Y]) @@ -479,7 +436,7 @@ def compute(n, ho, wo, hi, wi, k): Yb = tvm.tir.decl_buffer(output_shape, name="Yb", dtype=dtype, scope=storage_scope) binds = {X: Xb, Y: Yb} - return (s, [X, filt_packed, Y], binds) + return (s, [X, filt, Y], binds) class BaseConv2d: @@ -495,9 +452,23 @@ class BaseConv2d: h_split_factor = tvm.testing.parameter(1, 2) -class TestConv2dLogical(BaseConv2d): +class TestConv2dLogicalFilter(BaseConv2d): + conv2d_impl = tvm.testing.parameter(conv2d_nhwc8h8w32c, conv2d_nhw8h8wc) + @tvm.testing.parametrize_targets("llvm") - def test_conv2d(self, shape_nhwc, shape_oihw, kernel, stride, pad, dtype, target): + def test_conv2d( + self, + conv2d_impl, + shape_nhwc, + shape_oihw, + kernel, + stride, + pad, + dtype, + target, + k_split_factor, + h_split_factor, + ): inputs = [ np.random.uniform(0, 255, size=shape_nhwc).astype(dtype), np.random.uniform(0, 255, size=shape_oihw).astype(dtype), @@ -506,44 +477,24 @@ def test_conv2d(self, shape_nhwc, shape_oihw, kernel, stride, pad, dtype, target ref_output = testing.conv2d_nhwc_python(inputs[0], np_filter, stride, pad) output = build_and_run( inputs, - conv2d_logical, + conv2d_impl, target, target, shape_nhwc=shape_nhwc, - shape_oihw=shape_oihw, + shape_filter=shape_oihw, kernel_size=(kernel, kernel), stride=(stride, stride), padding=(pad, pad, pad, pad), dtype=dtype, + k_split_factor=k_split_factor, + h_split_factor=h_split_factor, ) - # nhwc8h8w32c -> nhwc - output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( - output.shape[0], - output.shape[1] * output.shape[4], - output.shape[2] * output.shape[5], - output.shape[3] * output.shape[6], - ) - - # slice output to match ref_output shape - # e.g. 8x8 spatial 3x3 filter = 6x6 ref output - # but still 8x8 output given the blocked layout - output = output[ - 0 : ref_output.shape[0] : 1, - 0 : ref_output.shape[1] : 1, - 0 : ref_output.shape[2] : 1, - 0 : ref_output.shape[3] : 1, - ] - - if "int" in dtype: - tol = {"atol": 0, "rtol": 0} - elif dtype == "float32": - tol = {"rtol": 1e-4, "atol": 2e-4} - tvm.testing.assert_allclose(output, ref_output, **tol) + verify_conv2d(output, ref_output, dtype) class TestConv2dPackedFilter(BaseConv2d): - conv2d_impl = tvm.testing.parameter(conv2d_packed_filter, conv2d_packed_filter_nhwhwc) + conv2d_impl = tvm.testing.parameter(conv2d_nhwc8h8w32c, conv2d_nhw8h8wc) @tvm.testing.parametrize_targets("llvm") @pytest.mark.skip("Skip due to being flaky on i386.") @@ -575,7 +526,7 @@ def test_conv2d( target, target, shape_nhwc=shape_nhwc, - shape_oihw8i32o4i=shape_oihw8i32o4i, + shape_filter=shape_oihw8i32o4i, kernel_size=(kernel, kernel), stride=(stride, stride), padding=(pad, pad, pad, pad), @@ -584,41 +535,7 @@ def test_conv2d( h_split_factor=h_split_factor, ) - # nhwc8h8w32c - if len(output.shape) == 7: - # nhwc8h8w32c -> nhwc - output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( - output.shape[0], - output.shape[1] * output.shape[4], - output.shape[2] * output.shape[5], - output.shape[3] * output.shape[6], - ) - - # nhwhwc - else: - # nhwhwc -> nhwc - output = output.transpose(0, 1, 3, 2, 4, 5).reshape( - output.shape[0], - output.shape[1] * output.shape[3], - output.shape[2] * output.shape[4], - output.shape[5], - ) - - # slice output to match ref_output shape - # e.g. 8x8 spatial 3x3 filter = 6x6 ref output - # but still 8x8 output given the blocked layout - output = output[ - 0 : ref_output.shape[0] : 1, - 0 : ref_output.shape[1] : 1, - 0 : ref_output.shape[2] : 1, - 0 : ref_output.shape[3] : 1, - ] - - if "int" in dtype: - tol = {"atol": 0, "rtol": 0} - elif dtype == "float32": - tol = {"rtol": 1e-4, "atol": 2e-4} - tvm.testing.assert_allclose(output, ref_output, **tol) + verify_conv2d(output, ref_output, dtype) if __name__ == "__main__": diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.md b/tests/python/contrib/test_hexagon/test_conv2d_conv2d.md new file mode 100644 index 000000000000..e42deb65c0c4 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_conv2d_conv2d.md @@ -0,0 +1,986 @@ + + + + + + + + + + + + + + + + + +Hexagon conv2d -> conv2d schedules + +# Baseline conv2d -> conv2d + +This is a baseline 1x1 conv2d -> 1x1 conv2d schedule for Hexagon. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-1-1-float32-llvm]" + +## Parameters + +| Parameter | Value | +| ------------------------ | ----- | +| Batch | 1 | +| Input Size | 64x64 | +| Input Channel | 128 | +| Conv2d #1 Pad | 0 | +| Conv2d #1 Stride | 1 | +| Conv2d #1 Kernel Size | 1 | +| Conv2d #1 Output Channel | 128 | +| Conv2d #2 Stride | 1 | +| Conv2d #2 Kernel Size | 1 | +| Conv2d #2 Output Channel | 128 | +| k_split | 1 | +| h_split | 1 | + +## Constants + +| Constant | Value | +| ------------------ | ----- | +| Conv2d #2 Pad | 0 | +| Conv2d #1 Dilation | 1 | +| Conv2d #2 Dilation | 1 | + +## Shapes and Layouts + +The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. + +| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | +| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | +| Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | +| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | +| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | +| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | +| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | + +## Schedule + +This is the conv2d compute schedule: + +``` + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + + for (ko.outer_1: int32, 0, 4) { + + // filter #1 cache read + + // conv2d #1 + for (wo: int32, 0, 8) { + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + } // end ko.outer_1 + + // filter #2 cache read + + // conv2d #2 + for (wo.c: int32, 0, 8) { + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + + // write back output cache + + } // end ho.outer + } // end ko.outer +``` + +Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. + +``` + for (ko.outer_1: int32, 0, 2) { +``` + +## Cache Usage + +*Input Cache* + +We compute over the WC8h8w32c portion of the input so we need 8 * 4 * 8 * 8 * 32 = 64kb for the input cache. + +``` + allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global; +``` + +*Filter Cache* + +We compute over the IHW8i32o4i portion of each filter so we need 4 * 1 * 1 * 8 * 32 * 4 = 4kb filter cache. + +``` + allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global; +``` + +Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. + +*Output Cache* + +We compute over the WK8h832k portion of the output where `k` denotes the output channel. The output cache is computed for each `ko.outer` which means it should be W * 8h * 8w * 32k = 8 * 8 * 8 * 32 = 16kb. And, in fact, this is the case for a single conv2d case. But, as already noted, for this conv2d -> conv2d case "the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2". This means that the output cache must grow accordingly to K * W * 8h * 8w * 32k = 4 * 8 * 8 * 8 * 32 = 64kb. There is a temporary allocation to store the results of conv2d #1: + +``` + allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global; +``` + +Note that the input cache is reused to store the results of conv2d #2. + +## Assumptions + +* n/a + +## To Do + +* Reuse of the input cache to store the results of conv2d #2 could be problematic for async copy. e.g. + +``` +slice 0: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global +slice 1: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global +``` + +In this case the store from slice 0: cache0 -> store -> global +can potentially block the load in slice 1: global -> load -> cache0 + +StorageRewrite is responsible for planning these caches, we'll need to understand how to avoid this for the async case. + +## Annotated TIR + +``` +primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} + buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c + placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i + placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i + placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc + buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { + allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global; + allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global; + allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global; + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + for (wo: int32, 0, 8) { + for (co: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + packed_input.global[(((((wo*8192) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)placeholder_8[((((((ho.outer*65536) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] + } + } + } + } + } + + // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 + for (ko.outer_1: int32, 0, 4) { + + // filter #1 cache read + for (co: int32, 0, 4) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_7[(((((ko.outer_1*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + + // conv2d #1 + for (wo: int32, 0, 8) { + + // init temp output to zero + for (hi.init: int32, 0, 8) { + for (wi.init: int32, 0, 8) { + for (ki.init: int32, 0, 32) { + temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 + } + } + } + + // compute + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] = + ( + (float32*)temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] + + ( + (float32*)packed_input.global[(((((wo*8192) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] * + (float32*)packed_filter.global[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } + } + + // filter #2 cache read + // NOTE: reusing same filter cache + for (co: int32, 0, 4) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_6[(((((ko.outer*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + + // conv2d #2 + for (wo.c: int32, 0, 8) { + + // init output cache to zero + // NOTE: reusing the input cache as the output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + packed_input.global[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // compute + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)temp_output[(((((wo.c*8192) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] * + (float32*)packed_filter.global[((((rc.outer_1*1024) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] + ) + ) + } + } + } + } + } + } + + // write back output cache + for (wo_1: int32, 0, 8) { + for (hi_1: int32, 0, 8) { + for (wi_1: int32, 0, 8) { + for (ki_1: int32, 0, 32) { + output_2[((((((ho.outer*65536) + (wo_1*8192)) + (ko.outer*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = + (float32*)packed_input.global[((((wo_1*2048) + (hi_1*256)) + (wi_1*32)) + ki_1)] + } + } + } + } + } + } +} +``` + +# Split on Channel Out and Height + +Uses parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-2-2-float32-llvm]" + +## Parameters + +| Parameter | Value | +| ------------------------ | ----- | +| Batch | 1 | +| Input Size | 64x64 | +| Input Channel | 128 | +| Conv2d #1 Pad | 0 | +| Conv2d #1 Stride | 1 | +| Conv2d #1 Kernel Size | 1 | +| Conv2d #1 Output Channel | 128 | +| Conv2d #2 Stride | 1 | +| Conv2d #2 Kernel Size | 1 | +| Conv2d #2 Output Channel | 128 | +| k_split | 2 ^ | +| h_split | 2 ^ | + +^ Changes from above + +## Constants + +| Constant | Value | +| ------------------ | ----- | +| Conv2d #2 Pad | 0 | +| Conv2d #1 Dilation | 1 | +| Conv2d #2 Dilation | 1 | + +## Shapes and Layouts + +The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. + +| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | +| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | +| Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | +| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | +| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | +| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | +| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | + +## Schedule + +This is the conv2d compute schedule: + +``` + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + + // input cache read + for (ho.inner: int32, 0, 2) { + ... + } + + for (ko.outer_1: int32, 0, 2) { + + // filter #1 cache read + for (ko.inner: int32, 0, 2) { + ... + } + + // conv2d #1 + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + } // end ko.outer_1 + + // filter #2 cache read + for (ko.inner: int32, 0, 2) { + ... + } + + // conv2d #2 + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + + // write back output cache + + } // end ho.outer + } // end ko.outer +``` + +The major change here versus above is the presence of `inner` loops for both channel out `ko` and height `ho` dimensions created from the `k_split` and `h_split` schedule parameters respectively, for example: + + +``` + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { +``` + +The effect of this change is increased cache usage given where the caches are computed in the schedule. Specifically, the input cache is now computed over `ho.inner` and the filter caches are computed over `ko.inner` which will grow the size of the cache. Details below. + +(Same as above) Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. + +``` + for (ko.outer_1: int32, 0, 2) { +``` + +## Cache Usage + +*Input Cache* + +The input cache grows by a factor of `h_split = 2` compared with above: + +``` + allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global; +``` + +*Filter Cache* + +The filter cache grows by a factor of `k_split = 2` compared with above: + +``` + allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global; +``` + +(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. + +*Output Cache* + +The output cache grows by a factor of `k_split = 2` compared with above: + +``` + allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global; +``` + +(Same as above) Note that the input cache is reused to store the results of conv2d #2. + +## Assumptions + +* n/a + +## To Do + +* n/a + +## Annotated TIR + +``` +primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} + buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c + placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i + placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i + placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc + buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { + allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global; + allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global; + allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global; + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + + // input cache read + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)placeholder_8[(((((((ho.outer*131072) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] + } + } + } + } + } + } + + // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 + for (ko.outer_1: int32, 0, 2) { + + // filter #1 cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 4) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_7[((((((ko.outer_1*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + } + + // conv2d #1 + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + + // init temp output to zero + for (hi.init: int32, 0, 8) { + for (wi.init: int32, 0, 8) { + for (ki.init: int32, 0, 32) { + temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 + } + } + } + + // compute + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + ( + (float32*)temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + + ( + (float32*)packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] * + (float32*)packed_filter.global[(((((ko.inner*4096) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } + } + } + } + + // filter #2 cache read + // NOTE: reusing same filter cache + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 4) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_6[((((((ko.outer*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + } + + // conv2d #2 + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache to zero + // NOTE: reusing the input cache as the output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // compute + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)temp_output[((((((ho.c.inner*65536) + (wo.c*8192)) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] * + (float32*)packed_filter.global[(((((ko.c.inner*4096) + (rc.outer_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] + ) + ) + } + } + } + } + } + } + } + } + + // write back output cache + for (ko.inner_1: int32, 0, 2) { + for (ho.inner_1: int32, 0, 2) { + for (wo_1: int32, 0, 8) { + for (hi_1: int32, 0, 8) { + for (wi_1: int32, 0, 8) { + for (ki_1: int32, 0, 32) { + output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = + (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] + } + } + } + } + } + } + } + } +} +``` + +# 3x3 conv2d -> conv2d (no padding) + +Change from a 1x1 filter to a 3x3 filter. + +## Command + +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-3-128-1-3-128-2-2-float32-llvm]" + +## Parameters + +| Parameter | Value | +| ------------------------ | ----- | +| Batch | 1 | +| Input Size | 64x64 | +| Input Channel | 128 | +| Conv2d #1 Pad | 0 | +| Conv2d #1 Stride | 1 | +| Conv2d #1 Kernel Size | 3 ^ | +| Conv2d #1 Output Channel | 128 | +| Conv2d #2 Stride | 1 | +| Conv2d #2 Kernel Size | 3 ^ | +| Conv2d #2 Output Channel | 128 | +| k_split | 2 | +| h_split | 2 | + +^ Changes from above + +## Constants + +| Constant | Value | +| ------------------ | ----- | +| Conv2d #2 Pad | 0 | +| Conv2d #1 Dilation | 1 | +| Conv2d #2 Dilation | 1 | + +## Shapes and Layouts + +The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. + +| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | +| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | +| Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | +| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | +| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] | +| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 62, 62, 128] | +| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] | +| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 60, 60, 128] | + +## Schedule + +This is the conv2d compute schedule: + +``` + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + + for (ko.outer_1: int32, 0, 2) { + for (ho.outer_1: int32, 0, 2) { + + // input cache read + for (ho.inner: int32, 0, 3) { + if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { + ... + } + } + + // filter #1 cache read + for (ko.inner: int32, 0, 2) { + ... + } + + // conv2d #1 + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + if (((ho.outer_1*2) + ho.inner) < 3) { + if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + } // end ho.outer_1 + } // end ko.outer_1 + + // filter #2 cache read + for (ko.inner: int32, 0, 2) { + ... + } + + // conv2d #2 + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh_1: int32, 0, 3) { + for (rw_1: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + + // write back output cache + + } // end ho.outer + } // end ko.outer +``` + +There are two major changes here: + +1) The first change is the farily obvious presence of the the kernel height `rh` and width `rw` iterators, for example: + +``` + for (rh_1: int32, 0, 3) { + for (rw_1: int32, 0, 3) { +``` + +The effect of this change is to grow the filter cache by the size of the kernel. Details below. + +2) The second change is a bit more tricky. Remember that we want to produce `h_split` (2) "full width" and "full channel depth" slices from each conv2d. Given the 3x3 kernel size there are several changes to the schedule regarding the handling of the height dimension. + +First, notice that in order to produce `h_split` (2) "full width" and "full channel depth" slices for conv2d #1 we will need `h_split + 1` (3) "full width" and "full channel depth" slices of the input. This is because a 3x3 kernel (as opposed to a 1x1 kernel) creates a many-to-one relationship between the spatial coordinates of the input relative to the output. To illustrate, the 3x3 kernel will "fall off the bottom" of the 2nd input slice requiring values from the vertically adjacent 3rd input slice in order to produce the 2nd full output slice. Hence, we have the following input cache read over `h_split + 1` (3) input slices: + +``` + for (ho.inner: int32, 0, 3) { + if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { +``` + +The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + +Second, notice that conv2d #1 must produce sufficient output in the height dimension before conv2d #2 can proceed. This is similar to the requirement that conv2d #1 in regard to the channel out dimension, but also different because we do not require *all* output in the height dimenson only *sufficient* output in the height dimension. How much output in the height dimension is required? The intuitive guess might be `h_split + 1` (3) slices but that is wrong and the reason is that the output spatial coordinates are "shrinking" relative to the input coordinate space due to lack of padding. Hence 2 output slices from conv2d #1 are sufficient as intput to calculate 2 output slices from conv2d #2 and we get the following independent loop over `ho.outer_1` for conv2d #1: + +``` + for (ho.outer_1: int32, 0, 2) { +``` + +There are similar `if` statements in the conv2d compute schedule to prevent computing off the "bottom" of the input and output. + +(Same as above) Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. + +``` + for (ko.outer_1: int32, 0, 2) { +``` + +## Cache Usage + +*Input Cache* + +The input cache grows to hold the vertically adjacent slice: + +``` + allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global; +``` + +*Filter Cache* + +The filter cache grows to hold the 3x3 filter filter: + +``` + allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global; +``` + +(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. + +*Output Cache* + +The output cache scales with the input cache: + +``` + allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global; +``` + +(Same as above) Note that the input cache is reused to store the results of conv2d #2. + +## Assumptions + +* n/a + +## To Do + +* There may be some opportunity to optimized cache reuse in this case as the vertically adjacent input slice from a previous input cache read will be reloaded as in a subsequent input cache read + +## Annotated TIR + +``` +primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () + attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} + buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c + placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i + placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i + placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc + buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { + allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global; + allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global; + allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global; + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 + for (ko.outer_1: int32, 0, 2) { + // NOTE: compute enough height of conv2d #1 before computing conv2d #2 + for (ho.outer_1: int32, 0, 2) { + + // input cache read + for (ho.inner: int32, 0, 3) { + if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)placeholder_8[((((((((ho.outer_1*131072) + (ho.outer*131072)) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] + } + } + } + } + } + } + } + + // filter #1 cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 4) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_7[((((((((ko.outer_1*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + } + } + } + + // conv2d #1 + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + + // init temp output to zero + if (((ho.outer_1*2) + ho.inner) < 3) { + for (hi.init: int32, 0, 8) { + for (wi.init: int32, 0, 8) { + for (ki.init: int32, 0, 32) { + temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 + } + } + } + } + + // compute + if (((ho.outer_1*2) + ho.inner) < 3) { + if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { + for (rc.outer: int32, 0, 4) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + ( + (float32*)temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + + ( + (float32*)packed_input.global[((((((((floordiv((hi + rh), 8)*65536) + (ho.inner*65536)) + (floordiv((wi + rw), 8)*8192)) + (wo*8192)) + (rc.outer*2048)) + (floormod((hi + rh), 8)*256)) + (floormod((wi + rw), 8)*32)) + rc.inner)] * + (float32*)packed_filter.global[(((((((ko.inner*36864) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] + ) + ) + } + } + } + } + } + } + } + } + } + } + } + } + } + } + + // filter #2 cache read + // NOTE: reusing same filter cache + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 4) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (cio: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (cii: int32, 0, 4) { + packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] = + (float32*)placeholder_6[((((((((ko.outer*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] + } + } + } + } + } + } + } + + // conv2d #2 + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache to zero + // NOTE: reusing the input cache as the output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } + } + } + + // compute + for (rc.outer_1: int32, 0, 4) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh_1: int32, 0, 3) { + for (rw_1: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner_1: int32, 0, 32) { + packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + ( + (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)temp_output[((((((((floordiv((hi.c + rh_1), 8)*65536) + (ho.c.inner*65536)) + (floordiv((wi.c + rw_1), 8)*8192)) + (wo.c*8192)) + (rc.outer_1*2048)) + (floormod((hi.c + rh_1), 8)*256)) + (floormod((wi.c + rw_1), 8)*32)) + rc.inner_1)] * + (float32*)packed_filter.global[(((((((ko.c.inner*36864) + (rc.outer_1*9216)) + (rh_1*3072)) + (rw_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] + ) + ) + } + } + } + } + } + } + } + } + } + } + + // write back output cache + for (ko.inner_1: int32, 0, 2) { + for (ho.inner_1: int32, 0, 2) { + for (wo_1: int32, 0, 8) { + for (hi_1: int32, 0, 8) { + for (wi_1: int32, 0, 8) { + for (ki_1: int32, 0, 32) { + output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = + (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] + } + } + } + } + } + } + } + } +} +``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py b/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py new file mode 100644 index 000000000000..937efec86189 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import tvm +from tvm import te +from tvm import topi +from tvm.topi import testing + +from .infrastructure import ( + ceildiv, + build_and_run, + get_block_shape, + get_conv2d_nhwc_shape, + get_filter_block_shape, + get_packed_filter_layout, + get_packed_activation_layout, + verify_conv2d, +) + +import numpy as np +import pytest + + +def conv2dconv2d( + shape_input, + pad1, + stride1, + dilation1, + shape_filter1, + pad2, + stride2, + dilation2, + shape_filter2, + k_split_factor, + h_split_factor, + dtype, + storage_scope="global", +): + """ + Conv2d -> Conv2d wherein the input activation is defined by its + logical NHWC layout. The filter is provided in its physical + packed layout (oihw8i32o4i). The input is padded and then packed + into its physical packed layout (nhwc8h8w32c). The resulting + computation is in the same physical packed layout (nhwc8h8w32c). + """ + + # nhwc layout + X = te.placeholder(shape_input, dtype=dtype) + + # oihw8i32o4i layout + filt_packed1 = te.placeholder(shape_filter1, dtype=dtype) + filt_packed2 = te.placeholder(shape_filter2, dtype=dtype) + + # calculate kernel size and output channels + # given oihw8i32o4i filter layout + kernel_size1 = tuple(shape_filter1[2:4]) + out_channels1 = shape_filter1[0] * shape_filter1[5] + + # get the the logical output shape of conv2d #1 + logical_output_shape1 = get_conv2d_nhwc_shape( + shape_input, + kernel_size1, + stride1, + pad1, + dilation1, + out_channels1, + ) + + block_shape = get_block_shape() + block_H, block_W, block_C = block_shape + + # Calculate padded input + N, H, W, C = shape_input + pad_h = (block_H - ((H + pad1[1]) % block_H)) % block_H + pad_w = (block_W - ((W + pad1[3]) % block_W)) % block_W + X_pad = topi.nn.pad( + X, [0, pad1[0], pad1[2], 0], [0, pad_h, pad_w, 0], pad_value=0, name="padded_input" + ) + + # Calculate packed input + packed_shape = get_packed_activation_layout(X_pad.shape, block_shape) + X_packed = te.compute( + packed_shape, + lambda n, ho, wo, co, hi, wi, ci: X_pad[ + n, ho * block_H + hi, wo * block_W + wi, co * block_C + ci + ], + name="packed_input", + ) + + filter_Cio, filter_Ki, filter_Cii = get_filter_block_shape() + filter_Ci = filter_Cio * filter_Cii + + rh = te.reduce_axis((0, kernel_size1[0]), name="rh") + rw = te.reduce_axis((0, kernel_size1[1]), name="rw") + rc = te.reduce_axis((0, C), name="rc") + + def compute(n, ho, wo, ko, hi, wi, ki): + h = ho * block_H + hi + h_contig = h * stride1[0] + rh + h_block_id = h_contig // block_H + h_block_offset = h_contig % block_H + + w = wo * block_W + wi + w_contig = w * stride1[1] + rw + w_block_id = w_contig // block_W + w_block_offset = w_contig % block_W + + c_block_id = rc // block_C + c_block_offset = rc % block_C + + rco = rc // filter_Ci + rcio = (rc % filter_Ci) // filter_Cii + rcii = rc % filter_Cii + + return te.sum( + X_packed[ + n, + h_block_id, + w_block_id, + c_block_id, + h_block_offset, + w_block_offset, + c_block_offset, + ] + * filt_packed1[ko, rco, rh, rw, rcio, ki, rcii], + axis=[rh, rw, rc], + ) + + output_shape1 = get_packed_activation_layout(logical_output_shape1, block_shape) + temp_Y = te.compute(output_shape1, compute, name="temp_output") + + # calculate kernel size and output channels + # given oihw8i32o4i filter layout + kernel_size2 = tuple(shape_filter2[2:4]) + out_channels2 = shape_filter2[0] * shape_filter2[5] + + # get the the logical output shape of conv2d #2 + logical_input_shape2 = logical_output_shape1 + logical_output_shape2 = get_conv2d_nhwc_shape( + logical_input_shape2, + kernel_size2, + stride2, + pad2, + dilation2, + out_channels2, + ) + + rh = te.reduce_axis((0, kernel_size2[0]), name="rh") + rw = te.reduce_axis((0, kernel_size2[1]), name="rw") + rc = te.reduce_axis((0, logical_input_shape2[3]), name="rc") + + def compute2(n, ho, wo, ko, hi, wi, ki): + h = ho * block_H + hi + h_contig = h * stride2[0] + rh + h_block_id = h_contig // block_H + h_block_offset = h_contig % block_H + + w = wo * block_W + wi + w_contig = w * stride2[1] + rw + w_block_id = w_contig // block_W + w_block_offset = w_contig % block_W + + c_block_id = rc // block_C + c_block_offset = rc % block_C + + rco = rc // filter_Ci + rcio = (rc % filter_Ci) // filter_Cii + rcii = rc % filter_Cii + + return te.sum( + temp_Y[ + n, + h_block_id, + w_block_id, + c_block_id, + h_block_offset, + w_block_offset, + c_block_offset, + ] + * filt_packed2[ko, rco, rh, rw, rcio, ki, rcii], + axis=[rh, rw, rc], + ) + + output_shape2 = get_packed_activation_layout(logical_output_shape2, block_shape) + Y = te.compute(output_shape2, compute2, name="output") + s = te.create_schedule(Y.op) + + s[X_pad].compute_inline() + s[X_packed].compute_inline() + + Xl = s.cache_read(X_packed, storage_scope, [temp_Y]) + F1l = s.cache_read(filt_packed1, storage_scope, [temp_Y]) + F2l = s.cache_read(filt_packed2, storage_scope, [Y]) + Yl = s.cache_write(Y, storage_scope) + + n, ho, wo, ko, hi, wi, ki = s[temp_Y].op.axis + rh, rw, rc = s[temp_Y].op.reduce_axis + rco, rci = s[temp_Y].split(rc, factor=block_C) + koo, koi = s[temp_Y].split(ko, factor=k_split_factor) + hoo, hoi = s[temp_Y].split(ho, factor=h_split_factor) + s[temp_Y].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) + s[Xl].compute_at(s[temp_Y], hoo) + s[F1l].compute_at(s[temp_Y], hoo) + + n, ho, wo, ko, hi, wi, ki = s[Y].op.axis + koo, koi = s[Y].split(ko, factor=k_split_factor) + hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) + s[Yl].compute_at(s[Y], hoo) + + n, ho, wo, ko, hi, wi, ki = s[Yl].op.axis + rh, rw, rc = s[Yl].op.reduce_axis + rco, rci = s[Yl].split(rc, factor=block_C) + koo, koi = s[Yl].split(ko, factor=k_split_factor) + hoo, hoi = s[Yl].split(ho, factor=h_split_factor) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) + + s[temp_Y].compute_at(s[Yl], hoo) + s[F2l].compute_at(s[Yl], hoo) + + binds = {} # TODO + return (s, [X, filt_packed1, filt_packed2, Y], binds) + + +class BaseConv2dConv2d: + # input + batch = tvm.testing.parameter(1) + in_size = tvm.testing.parameter(64) + in_channel = tvm.testing.parameter(128) + # conv2d #1 + pad1 = tvm.testing.parameter(0) + stride1 = tvm.testing.parameter(1) + kernel_size1 = tvm.testing.parameter(1, 3) + out_channel1 = tvm.testing.parameter(128) + # conv2d #2 + stride2 = tvm.testing.parameter(1) + kernel_size2 = tvm.testing.parameter(1, 3) + out_channel2 = tvm.testing.parameter(128) + # schedule params + k_split_factor = tvm.testing.parameter(1, 2) + h_split_factor = tvm.testing.parameter(1, 2) + dtype = tvm.testing.parameter("float32") + + +class TestConv2dConv2dPackedFilter(BaseConv2dConv2d): + @tvm.testing.parametrize_targets("llvm") + @pytest.mark.skip("Skip due to being flaky on i386.") + def test_conv2d( + self, + batch, + in_size, + in_channel, + pad1, + stride1, + kernel_size1, + out_channel1, + stride2, + kernel_size2, + out_channel2, + k_split_factor, + h_split_factor, + dtype, + target, + ): + # TODO: no support for padding in conv2d #2 + pad2 = 0 + + # TODO: no support for dilation + dilation1 = 1 + dilation2 = 1 + + shape_input = [batch, in_size, in_size, in_channel] + shape_filter1_oihw = [out_channel1, in_channel, kernel_size1, kernel_size1] + shape_filter1_oihw8i32o4i = get_packed_filter_layout( + out_channel1, in_channel, kernel_size1, kernel_size1 + ) + + shape_filter2_oihw = [out_channel2, out_channel1, kernel_size2, kernel_size2] + shape_filter2_oihw8i32o4i = get_packed_filter_layout( + out_channel2, out_channel1, kernel_size2, kernel_size2 + ) + + inputs = [ + np.random.uniform(0, 255, size=shape_input).astype(dtype), + np.random.uniform(0, 255, size=shape_filter1_oihw8i32o4i).astype(dtype), + np.random.uniform(0, 255, size=shape_filter2_oihw8i32o4i).astype(dtype), + ] + np_filter1 = ( + inputs[1] + .transpose(0, 5, 1, 4, 6, 2, 3) + .reshape(shape_filter1_oihw) + .transpose(2, 3, 1, 0) + ) + np_filter2 = ( + inputs[2] + .transpose(0, 5, 1, 4, 6, 2, 3) + .reshape(shape_filter2_oihw) + .transpose(2, 3, 1, 0) + ) + temp_output = testing.conv2d_nhwc_python(inputs[0], np_filter1, stride1, pad1) + ref_output = testing.conv2d_nhwc_python(temp_output, np_filter2, stride2, pad2) + output = build_and_run( + inputs, + conv2dconv2d, + target, + target, + shape_input=shape_input, + pad1=(pad1, pad1, pad1, pad1), + stride1=(stride1, stride1), + dilation1=(dilation1, dilation1), + shape_filter1=shape_filter1_oihw8i32o4i, + pad2=(pad2, pad2, pad1, pad1), + stride2=(stride2, stride2), + dilation2=(dilation2, dilation2), + shape_filter2=shape_filter2_oihw8i32o4i, + k_split_factor=k_split_factor, + h_split_factor=h_split_factor, + dtype=dtype, + ) + + verify_conv2d(output, ref_output, dtype) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 121edc4b8c60..6f23228be68c 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -47,12 +47,11 @@ def run_onnx(onnx_model, input_data): return res -def run_relay(func, data_tuple): +def run_relay(func, data_tuple, is_dyn=False): target = "llvm" dev = tvm.device("llvm", 0) - relay_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)( - *data_tuple - ) + kind = "graph" if not is_dyn else "vm" + relay_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(*data_tuple) result = [] relay_res = relay_res if isinstance(relay_res, list) else [relay_res] @@ -62,8 +61,8 @@ def run_relay(func, data_tuple): return result -def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0): - relay_results = run_relay(relay_func, indata) +def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0, is_dyn=False): + relay_results = run_relay(relay_func, indata, is_dyn) onnx_results = run_onnx(func_to_onnx(relay_func, test_name), indata) for relay_res, onnx_res in zip(relay_results, onnx_results): @@ -111,7 +110,7 @@ def verify_conv2d( func = relay.Function([x, w], y) data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) - verify_results(func, [data, kernel], "test_conv2d", rtol=1e-5, atol=1e-5) + verify_results(func, [data, kernel], "test_conv2d", rtol=1e-5, atol=1e-5, is_dyn=True) dshape = (1, 32, 18, 18) kshape = (32, 1, 3, 3) @@ -700,6 +699,26 @@ def verify_resize(dshape, outsize, method, coord_trans, rounding_method, dtype=" verify_resize(isize, osize, method=i, coord_trans=j, rounding_method=k) +def test_dyn(): + """Dynamic unit test.""" + + def verify_dyn_bcast(lhs_shape, rhs_shape, dtype): + lhs_dyn_shape = tuple(relay.Any() for i in range(len(lhs_shape))) + rhs_dyn_shape = tuple(relay.Any() for i in range(len(rhs_shape))) + x = relay.var("x", shape=lhs_dyn_shape, dtype=dtype) + y = relay.var("y", shape=rhs_dyn_shape, dtype=dtype) + z = relay.add(x, y) + func = relay.Function([x, y], z) + lhs_data = np.random.uniform(size=lhs_shape).astype(dtype) + rhs_data = np.random.uniform(size=rhs_shape).astype(dtype) + verify_results( + func, [lhs_data, rhs_data], "test_dyn_bcast", rtol=1e-5, atol=1e-5, is_dyn=True + ) + + verify_dyn_bcast((1, 3, 32, 1), (1, 3, 1, 3), "float32") + verify_dyn_bcast((1, 13), (4, 3, 5, 1), "float32") + + if __name__ == "__main__": test_add() test_bias_add() @@ -730,3 +749,4 @@ def verify_resize(dshape, outsize, method, coord_trans, rounding_method, dtype=" test_round() test_cast() test_resize() + test_dyn() diff --git a/tests/python/contrib/test_rpc_server_device.py b/tests/python/contrib/test_rpc_server_device.py new file mode 100644 index 000000000000..f1b8647683ac --- /dev/null +++ b/tests/python/contrib/test_rpc_server_device.py @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""iOS RPC Server tests.""" +# pylint: disable=invalid-name, no-value-for-parameter, missing-function-docstring, import-error +import sys +import multiprocessing +import pytest +import numpy as np + +import tvm.testing +import tvm.relay.testing +from tvm import te +from tvm import rpc +from tvm import relay, auto_scheduler +from tvm.contrib import utils, xcode, graph_executor +from tvm.autotvm.measure import request_remote +from tvm.auto_scheduler.measure_record import load_records +from tvm.auto_scheduler.measure import MeasureErrorNo +from tvm.auto_scheduler.utils import call_func_with_timeout +from tvm.contrib.popen_pool import PopenWorker, StatusKind +from tvm.rpc import tracker, proxy, server_ios_launcher + + +HOST_URL = "0.0.0.0" +HOST_PORT = 9190 +DEVICE_KEY = "ios_mobile_device" + + +TEMPORARY_DIRECTORY = utils.tempdir() +ARCH = "x86_64" +SDK = "iphonesimulator" +DSO_NAME = "lib.dylib" +DTYPE = "float32" + + +np.random.seed(0) + + +ios_rpc_bundle_description_required = pytest.mark.skipif( + not server_ios_launcher.ServerIOSLauncher.is_compatible_environment(), + reason="To run this test, you need to set environment variables required in ServerIOSLauncher.", +) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_teardown_actions(): + """Setup and teardown actions for pytest.""" + + # No setup actions + yield + # Teardown actions: + server_ios_launcher.ServerIOSLauncher.shutdown_booted_devices() + + +def setup_rpc_standalone_configuration(f): + """ + Host -- RPC server + """ + + def wrapper(): + with server_ios_launcher.ServerIOSContextManager( + mode=server_ios_launcher.RPCServerMode.standalone.value, + host=HOST_URL, + port=HOST_PORT, + key=DEVICE_KEY, + ) as ios_server: + f(host=ios_server.host, port=ios_server.port) + + return wrapper + + +def setup_rpc_proxy_configuration(f): + """ + Host -- Proxy -- RPC server + """ + + def wrapper(): + proxy_server = proxy.Proxy(host=HOST_URL, port=HOST_PORT) + with server_ios_launcher.ServerIOSContextManager( + mode=server_ios_launcher.RPCServerMode.proxy.value, + host=proxy_server.host, + port=proxy_server.port, + key=DEVICE_KEY, + ): + f(host=proxy_server.host, port=proxy_server.port) + proxy_server.terminate() + + return wrapper + + +def setup_rpc_tracker_configuration(f): + """ + tracker + / \ + Host -- RPC server + """ + + def wrapper(): + tracker_server = tracker.Tracker(host=HOST_URL, port=HOST_PORT, silent=True) + with server_ios_launcher.ServerIOSContextManager( + mode=server_ios_launcher.RPCServerMode.tracker.value, + host=tracker_server.host, + port=tracker_server.port, + key=DEVICE_KEY, + ): + f(host=tracker_server.host, port=tracker_server.port) + tracker_server.terminate() + + return wrapper + + +def setup_rpc_tracker_via_proxy_configuration(f): + """ + tracker + / \ + Host -- Proxy -- RPC server + """ + + def wrapper(): + tracker_server = tracker.Tracker(host=HOST_URL, port=HOST_PORT, silent=True) + proxy_server_tracker = proxy.Proxy( + host=HOST_URL, port=8888, tracker_addr=(tracker_server.host, tracker_server.port) + ) + with server_ios_launcher.ServerIOSContextManager( + mode=server_ios_launcher.RPCServerMode.proxy.value, + host=proxy_server_tracker.host, + port=proxy_server_tracker.port, + key=DEVICE_KEY, + ): + f(host=tracker_server.host, port=tracker_server.port) + proxy_server_tracker.terminate() + tracker_server.terminate() + + return wrapper + + +def wrapper_for_call_function_with_timeout(timeout, func, args=(), kwargs=None): + """Wrapper for call_func_with_timeout.""" + + def wrapper(*_args, **_kwargs): + """ + This wrapper is needed because the cloudpicle + cannot serialize objects that contain pointers (RPCSession) + """ + func(*_args, **_kwargs) + return StatusKind.COMPLETE + + worker = PopenWorker() + ret = call_func_with_timeout(worker, timeout=timeout, func=wrapper, args=args, kwargs=kwargs) + if isinstance(ret, Exception): + raise ret + return ret + + +def try_create_remote_session(session_factory, args=(), kwargs=None): + """Deadlock-safe RPC Session creation.""" + + try: + successful_attempt = True + results = [] + for _ in range(2): + ret = wrapper_for_call_function_with_timeout( + timeout=10, func=session_factory, args=args, kwargs=kwargs + ) + results.append(ret) + if not np.all(np.array(results) == StatusKind.COMPLETE): + raise ValueError("One or more sessions ended incorrectly.") + except Exception as e: # pylint: disable=broad-except + successful_attempt = False + print(e) + return successful_attempt + + +def ios_create_dylib(output, objects, **kwargs): # pylint: disable=unused-argument + xcode.create_dylib(output, objects, arch=ARCH, sdk=SDK) + + +ios_create_dylib.output_format = "dylib" + + +def export_lib(lib): + """Export lib to temporary directory.""" + + path_dso = TEMPORARY_DIRECTORY.relpath(DSO_NAME) + lib.export_library(path_dso, fcompile=ios_create_dylib) + return path_dso + + +def get_add_relay_module(a_numpy, b_numpy): + """Get simple relay module that add two tensors.""" + + a = relay.var("a", shape=a_numpy.shape, dtype=DTYPE) + b = relay.var("b", shape=b_numpy.shape, dtype=DTYPE) + params = {} + out = tvm.IRModule.from_expr(relay.add(a, b)) + return out, params + + +def get_add_module(target): + """Get simple module that add two tensors.""" + + n = te.var("n") + A = te.placeholder((n,), name="A") + B = te.placeholder((n,), name="B") + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + return tvm.build(s, [A, B, C], target=target, target_host=target, name="simple_add") + + +@pytest.mark.dependency() +@ios_rpc_bundle_description_required +@setup_rpc_standalone_configuration +def test_rpc_standalone(host, port): + status_ok = try_create_remote_session(session_factory=rpc.connect, args=(host, port)) + assert status_ok + + +@pytest.mark.dependency() +@ios_rpc_bundle_description_required +@setup_rpc_proxy_configuration +def test_rpc_proxy(host, port): + status_ok = try_create_remote_session( + session_factory=rpc.connect, args=(host, port, DEVICE_KEY) + ) + assert status_ok + + +@pytest.mark.dependency() +@ios_rpc_bundle_description_required +@setup_rpc_tracker_configuration +def test_rpc_tracker(host, port): + status_ok = try_create_remote_session( + session_factory=request_remote, args=(DEVICE_KEY, host, port) + ) + assert status_ok + + +@pytest.mark.dependency() +@ios_rpc_bundle_description_required +@setup_rpc_tracker_via_proxy_configuration +def test_rpc_tracker_via_proxy(host, port): + status_ok = try_create_remote_session( + session_factory=request_remote, args=(DEVICE_KEY, host, port) + ) + assert status_ok + + +@pytest.mark.dependency(depends=["test_rpc_standalone"]) +@ios_rpc_bundle_description_required +@setup_rpc_standalone_configuration +def test_can_call_remote_function_with_rpc_standalone(host, port): + remote_session = rpc.connect(host, port) + f = remote_session.get_function("runtime.GetFFIString") + assert f("hello") == "hello" + + +@pytest.mark.dependency(depends=["test_rpc_proxy"]) +@ios_rpc_bundle_description_required +@setup_rpc_proxy_configuration +def test_can_call_remote_function_with_rpc_proxy(host, port): + remote_session = rpc.connect(host, port, key=DEVICE_KEY) + f = remote_session.get_function("runtime.GetFFIString") + assert f("hello") == "hello" + + +@pytest.mark.dependency(depends=["test_rpc_tracker"]) +@ios_rpc_bundle_description_required +@setup_rpc_tracker_configuration +def test_can_call_remote_function_with_rpc_tracker(host, port): + remote_session = request_remote(DEVICE_KEY, host, port) + f = remote_session.get_function("runtime.GetFFIString") + assert f("hello") == "hello" + + +@pytest.mark.dependency(depends=["test_rpc_tracker_via_proxy"]) +@ios_rpc_bundle_description_required +@setup_rpc_tracker_via_proxy_configuration +def test_can_call_remote_function_with_rpc_tracker_via_proxy(host, port): + remote_session = request_remote(DEVICE_KEY, host, port) + f = remote_session.get_function("runtime.GetFFIString") + assert f("hello") == "hello" + + +@pytest.mark.dependency(depends=["test_rpc_standalone"]) +@ios_rpc_bundle_description_required +@setup_rpc_standalone_configuration +def test_basic_functionality_of_rpc_session(host, port): + remote_session = rpc.connect(host, port) + device = remote_session.cpu(0) + + target = tvm.target.Target(target=f"llvm -mtriple={ARCH}-apple-darwin") + lib = get_add_module(target) + path_dso = export_lib(lib) + + # Check correct upload + remote_session.upload(path_dso) + + # Check correct download + downloaded_lib = remote_session.download(DSO_NAME) + with open(path_dso, "rb") as source_lib_file: + assert downloaded_lib == bytearray( + source_lib_file.read() + ), "The downloaded module does not match the loaded module" + + # Check correct remote computing + lib = remote_session.load_module(DSO_NAME) + n = 100 + a = tvm.nd.array(np.random.uniform(size=n).astype(DTYPE), device) + b = tvm.nd.array(np.random.uniform(size=n).astype(DTYPE), device) + c = tvm.nd.array(np.zeros(n, dtype=DTYPE), device) + lib(a, b, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) + + # Check correct remove + remote_session.remove(DSO_NAME) + + +@pytest.mark.dependency(depends=["test_rpc_standalone"]) +@pytest.mark.xfail(reason="Not implemented functionality") +@ios_rpc_bundle_description_required +@setup_rpc_standalone_configuration +def test_cleanup_workspace_after_session_end(host, port): + # Arrange + remote_session = rpc.connect(host, port) + target = tvm.target.Target(target=f"llvm -mtriple={ARCH}-apple-darwin") + lib = get_add_module(target) + path_dso = export_lib(lib) + remote_session.upload(path_dso) + + # Act + del remote_session + remote_session = rpc.connect(host, port) + try: + remote_session.download(DSO_NAME) + status_ok = False + except Exception as _: # pylint: disable=broad-except + status_ok = True + + # Assert + assert status_ok, "Workspace not cleared after RPC Session termination." + + +@pytest.mark.dependency(depends=["test_rpc_standalone"]) +@ios_rpc_bundle_description_required +@setup_rpc_standalone_configuration +def test_graph_executor_remote_run(host, port): + remote_session = rpc.connect(host, port) + target = tvm.target.Target(target=f"llvm -mtriple={ARCH}-apple-darwin") + device = remote_session.cpu(0) + + size = 100 + a = np.random.uniform(size=size).astype(DTYPE) + b = np.random.uniform(size=size).astype(DTYPE) + mod, params = get_add_relay_module(a, b) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, target_host=target, params=params) + + path_dso = export_lib(lib) + remote_session.upload(path_dso) + lib = remote_session.load_module(DSO_NAME) + + gen_module = graph_executor.GraphModule(lib["default"](device)) + + # Check set input + gen_module.set_input("a", tvm.nd.array(a)) + gen_module.set_input("b", tvm.nd.array(b)) + tvm.testing.assert_allclose(gen_module.get_input(0).numpy(), a) + tvm.testing.assert_allclose(gen_module.get_input(1).numpy(), b) + + # Check run + gen_module.run() + out = gen_module.get_output(0) + tvm.testing.assert_allclose(out.numpy(), a + b) + + +@pytest.mark.dependency(depends=["test_rpc_tracker"]) +@ios_rpc_bundle_description_required +@setup_rpc_tracker_configuration +def test_check_auto_schedule_tuning(host, port): # pylint: disable=too-many-locals + log_file = TEMPORARY_DIRECTORY.relpath("ios_tuning_stat.log") + target = tvm.target.Target(target=f"llvm -mtriple={ARCH}-apple-darwin") + mod, params = relay.testing.mlp.get_workload(batch_size=4, image_shape=(1, 4, 4)) + + try: + status_ok = True + measure_runner = auto_scheduler.RPCRunner( + DEVICE_KEY, + host, + port, + min_repeat_ms=1, + timeout=10, + n_parallel=multiprocessing.cpu_count(), + ) + builder = auto_scheduler.LocalBuilder(timeout=10, build_func=ios_create_dylib) + tune_option = auto_scheduler.TuningOptions( + builder=builder, + num_measure_trials=2, + num_measures_per_round=1, + runner=measure_runner, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=0, + ) + + tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) + tasks, task_weights = tasks[:2], task_weights[:2] + tuner = auto_scheduler.TaskScheduler(tasks, task_weights) + tuner.tune(tune_option, search_policy="sketch.random") + + # Check tuning log + tuning_statistic = list(load_records(log_file)) + for _, measure_result in tuning_statistic: + if measure_result.error_no != MeasureErrorNo.NO_ERROR: + raise ValueError( + f"Error for MeasureResult. Error code: {measure_result.error_no}," + f" for details see MeasureErrorNO." + ) + + except Exception as e: # pylint: disable=broad-except + status_ok = False + print(e) + + assert status_ok, "Tuning failed, see logs." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index df4234e7e605..0ee5ce3118f2 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -1526,4 +1526,7 @@ def test_empty_subgraph(run_module): if __name__ == "__main__": - pytest.main([__file__]) + import sys + + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_maskrcnn_resnet50(run_module) diff --git a/tests/python/contrib/test_vitis_ai/infrastructure.py b/tests/python/contrib/test_vitis_ai/infrastructure.py index e87d4f874630..578ac37da25b 100644 --- a/tests/python/contrib/test_vitis_ai/infrastructure.py +++ b/tests/python/contrib/test_vitis_ai/infrastructure.py @@ -99,7 +99,7 @@ def build_module( ), "Got {} Vitis-AI partitions, expected {}".format( partition_count, vitis_ai_partitions ) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() return relay.build(mod, target, params=params) diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 835b2583c725..ca4ab2247bd9 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -126,6 +126,23 @@ def pytorch_resnet18(tmpdir_factory): return model_file_name +@pytest.fixture(scope="session") +def pytorch_mobilenetv2_quantized(tmpdir_factory): + try: + import torch + import torchvision.models as models + except ImportError: + # Not all environments provide Pytorch, so skip if that's the case. + return "" + model = models.quantization.mobilenet_v2(quantize=True) + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "mobilenet_v2_quantized.pth") + # Trace model into torchscript. + traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224)) + torch.jit.save(traced_cpu, model_file_name) + + return model_file_name + + @pytest.fixture(scope="session") def onnx_resnet50(): base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 9d44d8f22f41..2ef84d7f1a6f 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -24,7 +24,7 @@ import pytest import tvm -import tvm.testing +from tvm.testing.utils import ethosn_available from tvm.contrib.target.vitis_ai import vitis_ai_available @@ -370,8 +370,11 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): assert os.path.exists(dumps_path) -@tvm.testing.requires_ethosn -def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant): +@pytest.mark.skipif( + not ethosn_available(), + reason="--target=Ethos(TM)-N77 is not available. TVM built with 'USE_ETHOSN OFF'", +) +def test_compile_tflite_module_with_external_codegen_ethos_n77(tflite_mobilenet_v1_1_quant): pytest.importorskip("tflite") tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) tvmc_package = tvmc.compile(tvmc_model, target="ethos-n77, llvm", dump_code="relay") @@ -416,6 +419,26 @@ def test_compile_tflite_module_with_external_codegen_cmsisnn( assert len(c_source_files) == 3 +@pytest.mark.skipif( + not ethosn_available(), + reason="--target=Ethos(TM)-N78 is not available. TVM built with 'USE_ETHOSN OFF'", +) +def test_compile_tflite_module_with_external_codegen_ethos_n78(tflite_mobilenet_v1_1_quant): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + tvmc_package = tvmc.compile( + tvmc_model, target="ethos-n78 -variant=ethos-n78, llvm", dump_code="relay" + ) + dumps_path = tvmc_package.package_path + ".relay" + + # check for output types + assert type(tvmc_package) is TVMCPackage + assert type(tvmc_package.graph) is str + assert type(tvmc_package.lib_path) is str + assert type(tvmc_package.params) is bytearray + assert os.path.exists(dumps_path) + + @pytest.mark.skipif( not vitis_ai_available(), reason="--target=vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", @@ -474,8 +497,8 @@ def test_compile_tflite_module_with_external_codegen_ethosu( # The number of c_source_files depends on the number of fused subgraphs that # get offloaded to the NPU, e.g. conv2d->depthwise_conv2d->conv2d gets offloaded # as a single subgraph if both of these operators are supported by the NPU. - # Currently there are two source files for CPU execution and two offload graphs - assert len(c_source_files) == 4 + # Currently there are two source files for CPU execution and one offload graph + assert len(c_source_files) == 3 @mock.patch("tvm.relay.build") @@ -500,3 +523,9 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock config={"relay.ext.mock.options": {"testopt": "value"}}, disabled_pass=None, ) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/driver/tvmc/test_composite_target.py b/tests/python/driver/tvmc/test_composite_target.py index 0a0b45eeb970..80b4d1be93d5 100644 --- a/tests/python/driver/tvmc/test_composite_target.py +++ b/tests/python/driver/tvmc/test_composite_target.py @@ -34,6 +34,7 @@ def test_get_codegen_names(): names = tvmc.composite_target.get_codegen_names() assert "ethos-n77" in names + assert "ethos-n78" in names assert "vitis-ai" in names assert len(names) > 0 diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index 4d2fb56c5d4e..e742a1e5e4f7 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -204,7 +204,6 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant): tvmc.load(tflite_mobilenet_v1_1_quant, model_format="onnx") -@pytest.mark.skip(reason="https://github.com/apache/tvm/issues/7455") def test_load_model__pth(pytorch_resnet18): # some CI environments wont offer torch, so skip in case it is not present pytest.importorskip("torch") @@ -218,6 +217,21 @@ def test_load_model__pth(pytorch_resnet18): assert "layer1.0.conv1.weight" in tvmc_model.params.keys() +def test_load_quantized_model__pth(pytorch_mobilenetv2_quantized): + # some CI environments wont offer torch, so skip in case it is not present + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + tvmc_model = tvmc.load(pytorch_mobilenetv2_quantized, shape_dict={"input": [1, 3, 224, 224]}) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict + + # checking weights remain quantized and are not float32 + for p in tvmc_model.params.values(): + assert p.dtype in ["int8", "uint8", "int32"] # int32 for bias + + def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): # some CI environments wont offer pytorch, so skip in case it is not present pytest.importorskip("torch") diff --git a/tests/python/driver/tvmc/test_target.py b/tests/python/driver/tvmc/test_target.py index afb099f3add6..865542ee25c1 100644 --- a/tests/python/driver/tvmc/test_target.py +++ b/tests/python/driver/tvmc/test_target.py @@ -118,14 +118,17 @@ def test_parse_multiple_target(): assert "llvm" == targets[1]["name"] -def test_parse_multiple_target_with_opts(): - targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") +def test_parse_hybrid_target(): + """Hybrid Target and external codegen""" + targets = tvmc.common.parse_target( + "cmsis-nn -accelerator_config=ethos-u55-256, llvm -device=arm_cpu --system-lib" + ) assert len(targets) == 2 - assert "ethos-n77" == targets[0]["name"] - assert "myopt" in targets[0]["opts"] - assert "value" == targets[0]["opts"]["myopt"] + assert "cmsis-nn" == targets[0]["name"] + assert not targets[0]["is_tvm_target"] assert "llvm" == targets[1]["name"] + assert targets[1]["is_tvm_target"] def test_parse_quotes_and_separators_on_options(): @@ -141,3 +144,23 @@ def test_parse_quotes_and_separators_on_options(): assert len(targets_double_quote) == 1 assert "+v1.0x,+value" == targets_double_quote[0]["opts"]["option1"] + + +def test_parse_multiple_target_with_opts_ethos_n77(): + targets = tvmc.common.parse_target("ethos-n77 -myopt=value, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "ethos-n77" == targets[0]["name"] + assert "myopt" in targets[0]["opts"] + assert "value" == targets[0]["opts"]["myopt"] + assert "llvm" == targets[1]["name"] + + +def test_parse_multiple_target_with_opts_ethos_n78(): + targets = tvmc.common.parse_target("ethos-n78 -myopt=value, llvm -device=arm_cpu --system-lib") + + assert len(targets) == 2 + assert "ethos-n78" == targets[0]["name"] + assert "myopt" in targets[0]["opts"] + assert "value" == targets[0]["opts"]["myopt"] + assert "llvm" == targets[1]["name"] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index f6942299b751..b592d504fe7f 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -42,6 +42,16 @@ def test_mapping_target_args(): assert reconstruct_target_args(parsed) == {"llvm": {"mcpu": "cortex-m3"}} +def test_skip_target_from_codegen(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, left = parser.parse_known_args( + ["--target=cmsis-nn, c", "--target-cmsis-nn-from_device=1", "--target-c-mcpu=cortex-m55"] + ) + assert left == ["--target-cmsis-nn-from_device=1"] + assert reconstruct_target_args(parsed) == {"c": {"mcpu": "cortex-m55"}} + + def test_target_recombobulation_single(): tvm_target, _ = tvmc.common.target_from_cli("llvm", {"llvm": {"mcpu": "cortex-m3"}}) diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index f4c0cd102340..233977d66066 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -763,6 +763,94 @@ def test_forward_TanH(): _test_tanh(np.random.rand(10).astype(np.float32)) +####################################################################### +# Embed +# ----------- + + +def _test_embed(data, **kwargs): + """One iteration of Embed""" + _test_op(data, L.Embed, "Embed", **kwargs) + + +def test_forward_Embed(): + k = 20 + data = [i for i in range(k)] + np.random.shuffle(data) + # dimension is 1 + data = np.asarray(data) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 2 + data = np.reshape(data, [4, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 3 + data = np.reshape(data, [2, 2, 5]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + # dimension is 4 + data = np.reshape(data, [2, 2, 5, 1]) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=True, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + _test_embed( + data, + num_output=30, + input_dim=k, + bias_term=False, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ) + + ####################################################################### # Mobilenetv2 # ----------- diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 4dfe89fe40e5..114b8f961374 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -604,6 +604,26 @@ def test_forward_nested_layers(self, keras): ) verify_keras_frontend(keras_model) + def test_forward_l2_normalize(self, keras): + data = keras.layers.Input(shape=(16, 12, 8)) + K = keras.backend + l2_funcs = [ + keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=-2)), + keras.layers.Lambda(lambda v: K.l2_normalize(x=v, axis=-1)), + keras.layers.Lambda(lambda v: K.l2_normalize(axis=1, x=v)), + keras.layers.Lambda(lambda v: K.l2_normalize(v, 2)), + keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=3)), + keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=(2, 3))), + keras.layers.Lambda(lambda v: K.l2_normalize(v, (1, 2))), + keras.layers.Lambda(lambda v: K.l2_normalize(v, axis=[-2, -1])), + keras.layers.Lambda(lambda v: K.l2_normalize(v, [-3, -2])), + ] + for l2_func in l2_funcs: + x = l2_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, layout="NCHW") + verify_keras_frontend(keras_model, layout="NHWC") + if __name__ == "__main__": for k in [keras, tf_keras]: @@ -641,3 +661,4 @@ def test_forward_nested_layers(self, keras): sut.test_forward_zero_padding3d(keras=k) sut.test_forward_embedding(keras=k) sut.test_forward_repeat_vector(keras=k) + sut.test_forward_l2_normalize(keras=k) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dd1c77330986..f8870edcb6d1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1966,6 +1966,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11): verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False) # Split a single value to a single value verify_split([1], [[1]], [1], pass_split=True) + # Test that the default case modifies nothing when split list has length one + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [2], 1) + verify_split([[1.0, 2.0]], [[1.0, 2.0]], [1], 0) @tvm.testing.parametrize_targets diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index e3d1fc9daf2b..e427d6f563f9 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -125,6 +125,33 @@ def add_subtract3(inputs1, inputs2): verify_model(add_subtract3, [input_data, input_data2]) +@tvm.testing.uses_gpu +def test_forward_addmm(): + class Addmm(nn.Layer): + def __init__(self, alpha=1.0, beta=1.0): + super(Addmm, self).__init__() + self.alpha = alpha + self.beta = beta + + @paddle.jit.to_static + def forward(self, inputs, x, y): + return paddle.addmm(inputs, x, y, self.alpha, self.beta) + + input_shapes = [[10, 10], [1, 1], [7, 1]] + x_shapes = [[10, 3], [5, 6], [7, 7]] + y_shapes = [[3, 10], [6, 2], [7, 3]] + input_shapes = [[10, 10]] + x_shapes = [[10, 3]] + y_shapes = [[3, 10]] + + for i in range(len(input_shapes)): + input_data = paddle.rand(input_shapes[i], dtype="float32") + x_data = paddle.rand(x_shapes[i], dtype="float32") + y_data = paddle.rand(y_shapes[i], dtype="float32") + verify_model(Addmm(), input_data=[input_data, x_data, y_data]) + verify_model(Addmm(0.5, 0.3), input_data=[input_data, x_data, y_data]) + + @tvm.testing.uses_gpu def test_forward_arg_max_min(): class ArgMax(nn.Layer): @@ -279,6 +306,24 @@ def forward(self, input_data): verify_model(BatchNorm3D(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_bmm(): + class Bmm(nn.Layer): + def __init__(self): + super(Bmm, self).__init__() + + @paddle.jit.to_static + def forward(self, x, y): + return paddle.bmm(x, y) + + x_shapes = [[10, 3, 4], [5, 6, 2], [1, 7, 7]] + y_shapes = [[10, 4, 5], [5, 2, 7], [1, 7, 3]] + for i in range(len(x_shapes)): + x_data = paddle.rand(x_shapes[i], dtype="float32") + y_data = paddle.rand(y_shapes[i], dtype="float32") + verify_model(Bmm(), input_data=[x_data, y_data]) + + @tvm.testing.uses_gpu def test_forward_cast(): @paddle.jit.to_static @@ -382,31 +427,37 @@ def cusum3(inputs): @tvm.testing.uses_gpu def test_forward_conv(): - conv2d_input_shape = [1, 3, 10, 10] - class Conv2D1(nn.Layer): - def __init__(self): + def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"): super(Conv2D1, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, bias_attr=True) + self.conv = nn.Conv2D( + 3, + 6, + 3, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) self.softmax = nn.Softmax() @paddle.jit.to_static def forward(self, inputs): return self.softmax(self.conv(inputs)) - class Conv2D2(nn.Layer): - def __init__(self): - super(Conv2D2, self).__init__() - self.conv = nn.Conv2D(3, 6, 7, groups=3, bias_attr=False) - self.softmax = nn.Softmax() - - @paddle.jit.to_static - def forward(self, inputs): - return self.softmax(self.conv(inputs)) + input_shapes = [[1, 3, 10, 10], [1, 3, 12, 12]] - conv2d_input_data = paddle.rand(conv2d_input_shape, dtype="float32") - verify_model(Conv2D1(), input_data=conv2d_input_data) - verify_model(Conv2D2(), input_data=conv2d_input_data) + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Conv2D1(), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="VALID", dilation=3), input_data=input_data) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=3), input_data=input_data) + verify_model( + Conv2D1(stride=2, padding=3, dilation=3, padding_mode="replicate"), + input_data=input_data, + ) + verify_model(Conv2D1(stride=2, padding="SAME", dilation=2, groups=3), input_data=input_data) @tvm.testing.uses_gpu @@ -455,13 +506,25 @@ def forward(self, input1, input2): api_list = [ "equal", + "floor_divide", + "greater_equal", + "greater_than", + "less_equal", + "less_than", + "maximum", + "minimum", + "pow", ] x_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] y_shapes = [[1], [8, 20], [4, 1, 1], [2, 3, 8, 8], [2, 3, 3, 9, 1]] for x_shape, y_shape in zip(x_shapes, y_shapes): - x_data = paddle.randint(1, 1000, x_shape, dtype="int32") - y_data = paddle.randint(1, 1000, y_shape, dtype="int32") + x_data = paddle.randint(1, 10, x_shape, dtype="int32") + y_data = paddle.randint(1, 10, y_shape, dtype="int32") for api_name in api_list: + if api_name == "pow": + # only support float for pow + x_data = x_data.astype("float32") + y_data = y_data.astype("float32") verify_model(ElemwiseAPI(api_name), [x_data, y_data]) @@ -522,6 +585,118 @@ def forward(self, x, y): verify_model(ExpandAs(), [x_data, y_data]) +@tvm.testing.uses_gpu +def test_forward_flatten(): + class Flatten(nn.Layer): + def __init__(self, start_axis=0, stop_axis=-1): + super(Flatten, self).__init__() + self.start_axis = start_axis + self.stop_axis = stop_axis + + @paddle.jit.to_static + def forward(self, x): + return paddle.flatten(x, start_axis=self.start_axis, stop_axis=self.stop_axis) + + input_data = paddle.rand([2, 3, 4, 5, 2], dtype="float32") + verify_model(Flatten(), input_data=input_data) + verify_model(Flatten(2), input_data=input_data) + verify_model(Flatten(2, -2), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_gather(): + class Gather(nn.Layer): + def __init__(self, axis=None): + super(Gather, self).__init__() + self.axis = axis + + @paddle.jit.to_static + def forward(self, x, index): + return paddle.gather(x, index, axis=self.axis) + + x_shapes = [[20, 10], [10, 10, 8]] + index = paddle.to_tensor(np.array([1, 3, 5]).astype("int64")) + for x_shape in x_shapes: + x_data = paddle.rand(x_shape, dtype="float32") + verify_model(Gather(), [x_data, index]) + verify_model(Gather(axis=0), [x_data, index]) + verify_model(Gather(axis=1), [x_data, index]) + + +@tvm.testing.uses_gpu +def test_forward_gather_nd(): + class GatherNd(nn.Layer): + @paddle.jit.to_static + def forward(self, x, index): + return paddle.gather_nd(x, index) + + x_shapes = [[20], [8, 8], [4, 5, 6], [3, 4, 3, 5]] + y_shapes = [[2, 1], [2], [1, 2, 3], [3]] + for x_shape, y_shape in zip(x_shapes, y_shapes): + x_data = paddle.rand(x_shape, dtype="float32") + y_data = paddle.randint(low=0, high=3, shape=y_shape, dtype="int64") + verify_model(GatherNd(), [x_data, y_data]) + + +@tvm.testing.uses_gpu +def test_forward_group_norm(): + class GroupNorm(nn.Layer): + def __init__(self, channels, groups): + super(GroupNorm, self).__init__() + self.group_norm = paddle.nn.GroupNorm(num_channels=channels, num_groups=groups) + + def forward(self, inputs): + return self.group_norm(inputs) + + input_shapes = [[1, 4, 6, 6], [2, 2, 4, 7], [2, 8, 1, 1]] + for input_shape in input_shapes: + num_channels = input_shape[1] + input_data = paddle.uniform(input_shape) + verify_model(GroupNorm(num_channels, 1), input_data) + verify_model(GroupNorm(num_channels, 2), input_data) + + +@tvm.testing.uses_gpu +def test_forward_scatter(): + class Scatter(nn.Layer): + def __init__(self, overwrite=True): + super(Scatter, self).__init__() + self.overwrite = overwrite + + @paddle.jit.to_static + def forward(self, x, index, updates): + return paddle.scatter(x, index, updates, overwrite=self.overwrite) + + x_shapes = [[10], [4, 5], [6, 4, 5], [4, 5, 6, 4]] + index_shapes = [[10], [4], [6], [4]] + for x_shape, index_shape in zip(x_shapes, index_shapes): + x_data = paddle.rand(x_shape, dtype="float32") + updates = paddle.rand(x_shape, dtype="float32") + 1.0 + index = paddle.randint(low=0, high=3, shape=index_shape) + verify_model(Scatter(), [x_data, index, updates]) + verify_model(Scatter(False), [x_data, index, updates]) + + +def test_forward_scatter_nd(): + @paddle.jit.to_static + def scatter_nd(index, updates): + shape = [3, 5, 9, 10] + return paddle.scatter_nd(index, updates, shape) + + @paddle.jit.to_static + def scatter_nd_add(x, index, updates): + return paddle.scatter_nd_add(x, index, updates) + + index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64) + index = paddle.to_tensor(index_data) + updates = paddle.rand(shape=[3, 9, 10], dtype="float32") + verify_model(scatter_nd, [index, updates]) + x = paddle.rand(shape=[3, 5, 4, 9, 10], dtype="float32") + updates = paddle.rand(shape=[3, 2, 9, 10], dtype="float32") + index = paddle.randint(0, 3, shape=[3, 2, 3]) + verify_model(scatter_nd_add, [x, index, updates]) + + @tvm.testing.uses_gpu def test_forward_shape_full(): @paddle.jit.to_static @@ -538,6 +713,26 @@ def full2(inputs): verify_model(full2, input_data=[input_data]) +@tvm.testing.uses_gpu +def test_forward_squeeze(): + class Squeeze(nn.Layer): + def __init__(self, axis=None): + super(Squeeze, self).__init__() + self.axis = axis + + @paddle.jit.to_static + def forward(self, inputs): + return paddle.squeeze(inputs, axis=self.axis) + + input_shapes = [[1, 1, 3, 1, 5], [5, 1, 6]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Squeeze(axis=None), input_data=input_data) + verify_model(Squeeze(axis=1), input_data=input_data) + input_data = paddle.rand([1], dtype="float32") + verify_model(Squeeze(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_ones_like(): @paddle.jit.to_static @@ -587,6 +782,77 @@ def hard_swish(inputs): verify_model(hard_swish, input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_interpolate(): + class Interpolate(nn.Layer): + def __init__( + self, + mode="nearest", + align_corners=False, + align_mode=0, + data_format="NCHW", + use_scale=False, + use_list=False, + use_const=False, + ): + super(Interpolate, self).__init__() + self.mode = mode + self.align_corners = align_corners + self.align_mode = align_mode + self.data_format = data_format + self.use_scale = use_scale + self.use_list = use_list + self.use_const = use_const + + @paddle.jit.to_static + def forward(self, x): + size = np.array([15, 19]).astype("int32") + scale = np.array([2.0, 1.0]).astype("float32") + if not self.use_list and not self.use_const: + size = paddle.to_tensor(size) + scale = paddle.to_tensor(scale) + elif not self.use_const: + size0 = paddle.to_tensor(size[0:1]) + size = [size0, int(size[1])] + else: + size = size.tolist() + scale = scale.tolist() + if not self.use_scale: + return paddle.nn.functional.interpolate( + x, + size=size, + mode=self.mode, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format, + ) + else: + return paddle.nn.functional.interpolate( + x, + scale_factor=scale, + mode=self.mode, + align_corners=self.align_corners, + align_mode=self.align_mode, + data_format=self.data_format, + ) + + input_data = paddle.rand([1, 2, 8, 12]).astype("float32") + verify_model(Interpolate(), input_data) + verify_model(Interpolate(use_list=True), input_data) + verify_model(Interpolate(use_scale=True), input_data) + verify_model(Interpolate("bilinear", use_scale=True), input_data) + verify_model(Interpolate("bilinear", use_scale=True, align_corners=True), input_data) + verify_model( + Interpolate( + "bilinear", use_scale=True, align_corners=True, align_mode=1, data_format="NHWC" + ), + input_data, + ) + verify_model( + Interpolate("bicubic", use_scale=True, align_corners=True, align_mode=1), input_data + ) + + @tvm.testing.uses_gpu def test_forward_layer_norm(): @paddle.jit.to_static @@ -648,6 +914,22 @@ def forward(self, x, y): verify_model(LogicalAPI("logical_xor"), [x_data, y_data]) +@tvm.testing.uses_gpu +def test_forward_logical_not(): + class LogicalNot(nn.Layer): + def __init__(self): + super(LogicalNot, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + return paddle.logical_not(x).astype("int32") + + input_shapes = [[128], [8, 20], [4, 20, 3], [2, 3, 8, 8], [2, 3, 3, 9, 9]] + for input_shape in input_shapes: + input_data = paddle.randint(-2, 2, input_shape).astype("bool") + verify_model(LogicalNot(), input_data) + + @tvm.testing.uses_gpu def test_forward_look_up(): @paddle.jit.to_static @@ -722,24 +1004,141 @@ def forward(self, input1, input2): @tvm.testing.uses_gpu def test_forward_pool2d(): - @paddle.jit.to_static - def pool2d1(inputs): - return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) + class Pool2D1(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d(inputs, kernel_size=2, stride=2, padding=0) - @paddle.jit.to_static - def pool2d2(inputs): - return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) + class Pool2D2(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.adaptive_avg_pool2d(inputs, output_size=[3, 3]) - @paddle.jit.to_static - def pool2d3(inputs): - return nn.functional.max_pool2d( - inputs, kernel_size=2, stride=2, padding=0, return_mask=True - ) + class Pool2D3(nn.Layer): + @paddle.jit.to_static + def forward(self, inputs): + return nn.functional.avg_pool2d( + inputs, + kernel_size=3, + stride=1, + padding=[1, 1], + exclusive=False, + divisor_override=2.5, + ) + + input_shapes = [[1, 2, 8, 8], [1, 3, 10, 10]] + for input_shape in input_shapes: + input_data = paddle.uniform(shape=input_shape, dtype="float32", min=-1, max=1) + verify_model(Pool2D1(), input_data=input_data) + verify_model(Pool2D2(), input_data=input_data) + verify_model(Pool2D3(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_pad1d(): + class Pad1D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCL"): + super(Pad1D, self).__init__() + self.pad1d = paddle.nn.Pad1D(padding, mode=mode, value=value, data_format=data_format) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad1d(inputs) + + input_shapes = [[1, 2, 5], [2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad1D(padding=2), input_data=input_data) + verify_model(Pad1D(padding=[1, 2], data_format="NLC"), input_data=input_data) + verify_model(Pad1D(padding=[0, 2], value=0.3), input_data=input_data) + verify_model(Pad1D(padding=[2, 2], mode="reflect"), input_data=input_data) + verify_model(Pad1D(padding=3, mode="replicate"), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_pad2d(): + class Pad2D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCHW"): + super(Pad2D, self).__init__() + self.pad2d = paddle.nn.Pad2D(padding, mode=mode, value=value, data_format=data_format) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad2d(inputs) + + input_shapes = [[1, 2, 5, 5], [2, 2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad2D(padding=2), input_data=input_data) + verify_model(Pad2D(padding=[1, 2, 0, 2], data_format="NHWC"), input_data=input_data) + verify_model(Pad2D(padding=[1, 2, 0, 2], value=0.3), input_data=input_data) + verify_model(Pad2D(padding=[1, 2, 0, 2], mode="reflect"), input_data=input_data) + verify_model(Pad2D(padding=3, mode="replicate"), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_pad3d(): + class Pad3D(nn.Layer): + def __init__(self, padding=0, mode="constant", value=0.0, data_format="NCDHW"): + super(Pad3D, self).__init__() + self.pad3d = paddle.nn.Pad3D(padding, mode=mode, value=value, data_format=data_format) + + @paddle.jit.to_static + def forward(self, inputs): + return self.pad3d(inputs) + + input_shapes = [[1, 2, 2, 5, 5], [1, 2, 2, 5, 9]] + for input_shape in input_shapes: + input_data = paddle.rand(input_shape, dtype="float32") + verify_model(Pad3D(padding=2), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], data_format="NDHWC"), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], value=0.3), input_data=input_data) + verify_model(Pad3D(padding=[1, 2, 0, 2, 1, 1], mode="reflect"), input_data=input_data) + verify_model(Pad3D(padding=3, mode="replicate"), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_transpose(): + class Transpose(nn.Layer): + def __init__(self, perm): + super(Transpose, self).__init__() + self.perm = perm + + @paddle.jit.to_static + def forward(self, inputs): + inputs = inputs + inputs.size() + return paddle.transpose(inputs, perm=self.perm) + + input_data = paddle.rand([1, 3, 5, 4, 3], dtype="float32") + verify_model(Transpose([0, 1, 2, 3, 4]), input_data=input_data) + verify_model(Transpose([4, 3, 2, 0, 1]), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_reduce(): + class Reduce(nn.Layer): + def __init__(self, op_name, axis=None, keepdim=False): + super(Reduce, self).__init__() + self.op_name = op_name + self.axis = axis + self.keepdim = keepdim - input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) - verify_model(pool2d1, input_data=input_data) - verify_model(pool2d2, input_data=input_data) - # verify_model(pool2d3, input_data=input_data) + @paddle.jit.to_static + def forward(self, inputs): + result = getattr(paddle, self.op_name)(inputs, axis=self.axis, keepdim=self.keepdim) + result = result.astype("float32") + return result + + input_shapes = [[1, 2, 2, 5, 5], [2, 3, 4], [4, 20], [2, 3, 30, 30]] + for input_shape in input_shapes: + input_data = paddle.uniform(min=-3, max=3, shape=input_shape, dtype="float32") + verify_model(Reduce("all"), input_data=input_data.astype("bool")) + verify_model(Reduce("any", 1), input_data=input_data.astype("bool")) + verify_model(Reduce("max", 0, True), input_data=input_data) + verify_model(Reduce("min", 1, True), input_data=input_data) + verify_model(Reduce("prod", 0), input_data=input_data) + verify_model(Reduce("sum", 0, True), input_data=input_data) + verify_model(Reduce("mean", -1, True), input_data=input_data) @tvm.testing.uses_gpu @@ -842,14 +1241,46 @@ def forward(self, inputs): return self.func(inputs) api_list = [ + "abs", + "acos", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", "exp", + "floor", + "hardshrink", + "hardtanh", + "log", + "log2", + "log10", + "reciprocal", "relu", + "relu6", + "round", + "rsqrt", + "selu", + "sigmoid", + "sign", + "sin", + "sinh", + "softplus", + "softsign", + "sqrt", + "square", + "swish", + "tan", "tanh", ] input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]] for input_shape in input_shapes: input_data = paddle.rand(input_shape, dtype="float32") for api_name in api_list: + if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]: + # avoid illegal input, all elements should be positive + input_data = paddle.uniform(input_shape, min=0.01, max=0.99) verify_model(MathAPI(api_name), input_data=input_data) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0031f4143fab..5057f0d2b6b8 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -735,13 +735,30 @@ def test_forward_log_sigmoid(): @tvm.testing.uses_gpu -def test_forward_adaptiveavgpool(): +def test_forward_adaptive_avgpool(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(torch.nn.AdaptiveAvgPool2d([1, 1]).eval(), input_data=input_data) verify_model(torch.nn.AdaptiveAvgPool2d([10, 10]).eval(), input_data=input_data) + input_data = torch.rand([1, 3, 10]).float() + verify_model(torch.nn.AdaptiveAvgPool1d([1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveAvgPool1d([5]).eval(), input_data=input_data) + + +@tvm.testing.uses_gpu +def test_forward_adaptive_maxpool(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.AdaptiveMaxPool2d([1, 1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveMaxPool2d([10, 10]).eval(), input_data=input_data) + + input_data = torch.rand([1, 3, 10]).float() + verify_model(torch.nn.AdaptiveMaxPool1d([1]).eval(), input_data=input_data) + verify_model(torch.nn.AdaptiveMaxPool1d([5]).eval(), input_data=input_data) + @tvm.testing.uses_gpu def test_forward_maxpool2d(): @@ -3992,5 +4009,16 @@ def test_fn(out_int32=False, right=False): verify_model(test_fn(out_int32=True, right=True), [values, boundaries]) +@tvm.testing.uses_gpu +def test_roll(): + def test_fn(shifts, dims): + return lambda x: torch.roll(x, shifts, dims) + + x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + verify_model(test_fn(1, 0), [x]) + verify_model(test_fn(-1, 0), [x]) + verify_model(test_fn(shifts=(2, 1), dims=(0, 1)), [x]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 754976ca8c13..cb9122bfdc83 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -25,6 +25,7 @@ import pytest import numpy as np import tvm +import tempfile from tvm import te from tvm import relay @@ -1769,16 +1770,6 @@ def _test_unary_elemwise(math_op, data): compare_tflite_with_tvm(data, ["in:0"], [in_data], [out]) -####################################################################### -# Abs -# --- - - -def _test_abs(data): - """One iteration of abs""" - return _test_unary_elemwise(math_ops.abs, data) - - ####################################################################### # Ceil # ---- @@ -1859,26 +1850,6 @@ def _test_tan(data): return _test_unary_elemwise(math_ops.tan, data) -####################################################################### -# Sqrt -# ---- - - -def _test_sqrt(data): - """One iteration of sqrt""" - return _test_unary_elemwise(math_ops.sqrt, data) - - -####################################################################### -# Neg -# --- - - -def _test_neg(data): - """One iteration of neg""" - return _test_unary_elemwise(math_ops.neg, data) - - ####################################################################### # Square # ------ @@ -1901,20 +1872,17 @@ def _test_elu(data): def _test_forward_unary_elemwise(test_op): # functions that need positive input - if test_op.__name__ in {"_test_log", "_test_sqrt"}: + if test_op.__name__ in {"_test_log"}: test_op(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3))) else: test_op(np.random.uniform(-10, 10, (3, 2)).astype(np.float32)) def test_all_unary_elemwise(): - _test_forward_unary_elemwise(_test_abs) _test_forward_unary_elemwise(_test_floor) _test_forward_unary_elemwise(_test_exp) _test_forward_unary_elemwise(_test_log) _test_forward_unary_elemwise(_test_sin) - _test_forward_unary_elemwise(_test_sqrt) - _test_forward_unary_elemwise(_test_neg) _test_forward_unary_elemwise(_test_square) # ceil and cos come with TFLite 1.14.0.post1 fbs schema if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): @@ -2812,6 +2780,20 @@ def test_forward_pad(): ], mode="SYMMETRIC", ) + _test_pad( + [ + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)), + np.array([[1, 1], [2, 2]], dtype=np.int64), + ], + mode="REFLECT", + ) + _test_pad( + [ + np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)), + np.array([[1, 1], [2, 2]], dtype=np.int64), + ], + mode="SYMMETRIC", + ) _test_pad( [ np.arange(0, 256, dtype=np.uint8).reshape((1, 256)), @@ -3381,6 +3363,149 @@ def test_forward_rsqrt(): _test_rsqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) +####################################################################### +# SQRT +# ---- + + +def _test_sqrt(data, quantized=False): + """One iteration of SQRT""" + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=1, max=6, name="inq_0" + ) + input_range = {"inq_0": (1, 6)} + out = math_ops.sqrt(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out") + compare_tflite_with_tvm( + data, + "inq_0:0", + [inq_data], + [out], + quantized=True, + input_range=input_range, + experimental_new_converter=True, + ) + else: + out = math_ops.sqrt(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) + + +def test_forward_sqrt(): + """SQRT""" + _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32), quantized=False) + _test_sqrt(np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False) + _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8), quantized=True) + _test_sqrt(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) + + +####################################################################### +# NEG +# ---- + + +def _test_neg(data, quantized=False): + """One iteration of NEG""" + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0") + + if quantized: + inq_data = tf.quantization.fake_quant_with_min_max_args( + in_data, min=1, max=6, name="inq_0" + ) + input_range = {"inq_0": (1, 6)} + out = math_ops.neg(inq_data) + out = tf.quantization.fake_quant_with_min_max_args(out, min=1, max=6, name="out") + compare_tflite_with_tvm( + data, + "inq_0:0", + [inq_data], + [out], + quantized=True, + input_range=input_range, + experimental_new_converter=True, + ) + else: + out = math_ops.neg(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) + + +def test_forward_neg(): + """NEG""" + _test_neg(np.arange(-2.0, 4.0, dtype=np.float32), quantized=False) + _test_neg(np.arange(-2.0, 4.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False) + _test_neg(np.arange(1, 240, 40, dtype=np.uint8), quantized=True) + _test_neg(np.arange(1, 240, 40, dtype=np.uint8).reshape((2, 1, 3)), quantized=True) + + +####################################################################### +# ABS +# ---- + + +def _test_abs(data, quantized=False): + """One iteration of ABS""" + if quantized: + + def _create_model(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + op = tf.math.abs(x) + return op + + dtype = "int8" + model = Model() + + # Save the model + export_dir = tempfile.gettempdir() + "/tf_model" + tf.saved_model.save( + model, + export_dir, + signatures=model.tf_function.get_concrete_function( + tf.TensorSpec(data.shape, tf.float32, name="input"), + ), + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + tmp_data = np.random.rand(*tuple(data.shape)) + yield [tmp_data.astype(np.float32) * 2 - 1] + + converter = tf.lite.TFLiteConverter.from_saved_model(export_dir) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_model_quant = _create_model() + tflite_output = run_tflite_graph(tflite_model_quant, data) + in_node = ["serving_default_input_int8"] + tvm_output = run_tvm_graph(tflite_model_quant, data, in_node) + tvm.testing.assert_allclose( + np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2 + ) + else: + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_0") + out = math_ops.abs(in_data) + compare_tflite_with_tvm(data, "in_0:0", [in_data], [out]) + + +def test_forward_abs(): + """ABS""" + _test_abs(np.arange(-3.0, 3.0, dtype=np.float32), quantized=False) + _test_abs(np.arange(-3.0, 3.0, dtype=np.float32).reshape((2, 1, 3)), quantized=False) + _test_abs(np.arange(-128, 127, 45, dtype=np.int8), quantized=True) + + ####################################################################### # ReLu # ---- @@ -4657,6 +4782,9 @@ def test_prevent_tensorflow_dynamic_range(): test_forward_softmax() test_forward_tanh() test_forward_rsqrt() + test_forward_neg() + test_forward_abs() + test_forward_sqrt() test_forward_relu() test_forward_relu6() test_forward_leaky_relu() diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 746f595a4422..278a95f60b6c 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -33,11 +33,13 @@ import tvm from tvm import relay +from tvm import te from tvm.contrib import utils, graph_executor -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler +from tvm.relay.backend.te_compiler import TECompiler from tvm.relay.backend.utils import mangle_module_name from tvm.micro import export_model_library_format - +from tvm.micro.testing import mlf_extract_workspace_size_bytes _LOG = logging.getLogger(__name__) @@ -536,12 +538,6 @@ def create_header_file(tensor_name, npy_data, output_path, data_linkage): header_file.write("};\n\n") -def extract_main_workspace_size_bytes(extract_dir): - with open(os.path.join(extract_dir, "metadata.json")) as json_f: - metadata = json.load(json_f) - return metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"] - - def compile_models( models: Union[List[AOTTestModel], AOTTestModel], interface_api: str, @@ -624,7 +620,7 @@ def run_and_check( t.extractall(base_path) workspace_bytes += model.extra_memory_in_bytes - workspace_bytes += extract_main_workspace_size_bytes(base_path) + workspace_bytes += mlf_extract_workspace_size_bytes(tar_file) for key in model.inputs: sanitized_tensor_name = re.sub(r"\W", "_", key) @@ -721,7 +717,6 @@ def compile_and_run( def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" - compile_engine.get().clear() with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target, params=params) diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk index 8d03ccc5b5f4..553ed84277c6 100644 --- a/tests/python/relay/aot/corstone300.mk +++ b/tests/python/relay/aot/corstone300.mk @@ -64,11 +64,12 @@ CRT_SRCS = $(shell find $(CRT_ROOT)) CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c)) CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS)) CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c) -CMSIS_NN_SRCS = $(shell find ${CMSIS_PATH}/CMSIS/NN/Source/*/*.c) UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) +CMSIS_NN_LIBS = $(wildcard ${CMSIS_PATH}/CMSIS/NN/build/Source/*/*.a) + ifdef ETHOSU_TEST_ROOT -ETHOSU_ARCHIVE=${build_dir}/ethosu_core_driver/libethosu_core_driver.a +ETHOSU_DRIVER_LIBS = $(wildcard ${DRIVER_PATH}/build/*.a) ETHOSU_INCLUDE=-I$(ETHOSU_TEST_ROOT) endif @@ -93,24 +94,13 @@ ${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS) $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_startup.a) $(abspath $(build_dir))/libcmsis_startup/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_startup.a) -${build_dir}/libcmsis_nn.a: $(CMSIS_NN_SRCS) - $(QUIET)mkdir -p $(abspath $(build_dir)/libcmsis_nn) - $(QUIET)cd $(abspath $(build_dir)/libcmsis_nn) && $(CC) -c $(PKG_CFLAGS) -D${ARM_CPU} $^ - $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcmsis_nn.a) $(abspath $(build_dir))/libcmsis_nn/*.o - $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcmsis_nn.a) - ${build_dir}/libuart.a: $(UART_SRCS) $(QUIET)mkdir -p $(abspath $(build_dir)/libuart) $(QUIET)cd $(abspath $(build_dir)/libuart) && $(CC) -c $(PKG_CFLAGS) $^ $(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a) -${build_dir}/ethosu_core_driver/libethosu_core_driver.a: - $(QUIET)mkdir -p $(@D) - $(QUIET)cd $(DRIVER_PATH) && $(CMAKE) -B $(abspath $(build_dir)/ethosu_core_driver) $(DRIVER_CMAKE_FLAGS) - $(QUIET)cd $(abspath $(build_dir)/ethosu_core_driver) && $(MAKE) - -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(ETHOSU_ARCHIVE) +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(CMSIS_NN_LIBS) $(ETHOSU_DRIVER_LIBS) $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(PKG_CFLAGS) $(ETHOSU_INCLUDE) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) @@ -132,4 +122,4 @@ run: $(build_dir)/aot_test_runner .DEFAULT: aot_test_runner -.PHONY: run \ No newline at end of file +.PHONY: run diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 22583eda4a40..7669d02cd536 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -41,7 +41,7 @@ def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets() tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() @tvm.testing.uses_gpu @@ -251,7 +251,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) @pytest.mark.parametrize( diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8788faf45866..f42f7ad7ca69 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -23,6 +23,7 @@ from tvm import relay, te from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +from tvm.topi.testing import searchsorted_ref from utils import ref_funcs from utils.assert_diagnostic import DiagnosticTesting @@ -2086,5 +2087,35 @@ def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, ax verify_gather((relay.Any(), relay.Any()), (relay.Any(), relay.Any()), (2, 3), (1, 3), 0) +@tvm.testing.uses_gpu +def test_searchsorted(): + def verify_searchsorted( + sorted_sequence_shape, values_shape, sorted_sequence_shape_np, values_shape_np + ): + x = relay.var("x", relay.TensorType(sorted_sequence_shape, "float32")) + y = relay.var("y", relay.TensorType(values_shape, "float32")) + z = relay.searchsorted(x, y) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + x_np = np.sort(np.random.uniform(size=sorted_sequence_shape_np).astype("float32"), axis=-1) + y_np = np.random.uniform(size=values_shape_np).astype("float32") + + ref_res = searchsorted_ref(x_np, y_np, False, "int32") + check_result([x_np, y_np], mod, [ref_res]) + + for shape_np, values_shape_np in zip([(8, 9, 10), (10,), (11,)], [(8, 9, 20), (5,), (8, 9, 7)]): + sorted_sequence_shape = (relay.Any(),) * len(shape_np) + values_shape = (relay.Any(),) * len(values_shape_np) + + verify_searchsorted( + sorted_sequence_shape, + values_shape, + shape_np, + values_shape_np, + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py new file mode 100644 index 000000000000..ebda4ff47cac --- /dev/null +++ b/tests/python/relay/test_executor.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm import TVMError +from tvm.relay.backend import Executor + + +def test_create_executor(): + executor = Executor("aot") + assert str(executor) == "aot" + + +def test_create_executor_with_options(): + executor = Executor("aot", {"interface-api": "c"}) + assert str(executor) == "aot" + assert executor["interface-api"] == "c" + + +def test_create_executor_with_default(): + executor = Executor("graph") + assert not executor["link-params"] + + +def test_attr_check(): + executor = Executor("aot", {"interface-api": "c"}) + assert "woof" not in executor + assert "interface-api" in executor + + +def test_create_executor_not_found(): + with pytest.raises(TVMError, match='Executor "woof" is not defined'): + Executor("woof", {}) + + +def test_create_executor_attr_not_found(): + with pytest.raises(TVMError, match='Attribute "woof" is not available on this Executor'): + Executor("aot", {"woof": "bark"}) + + +def test_create_executor_attr_type_incorrect(): + with pytest.raises( + TVMError, + match='Attribute "interface-api" should have type "runtime.String"' + ' but instead found "IntImm"', + ): + Executor("aot", {"interface-api": True}) + + +def test_list_executors(): + assert "aot" in Executor.list_executors() + + +@pytest.mark.parametrize("executor", [Executor("aot"), "aot"]) +def test_list_executor_options(executor): + aot_options = Executor.list_executor_options(executor) + assert "interface-api" in aot_options + assert aot_options["interface-api"] == "runtime.String" + + +def test_list_executor_options_not_found(): + with pytest.raises(TVMError, match='Executor "woof" is not defined'): + Executor.list_executor_options("woof") diff --git a/tests/python/relay/test_ir_bind.py b/tests/python/relay/test_ir_bind.py index b179096a0528..0ab0122fa798 100644 --- a/tests/python/relay/test_ir_bind.py +++ b/tests/python/relay/test_ir_bind.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """ test bind function.""" +import pytest import tvm from tvm import te from tvm import relay +from tvm import TVMError def test_bind_params(): @@ -34,5 +36,16 @@ def test_bind_params(): assert tvm.ir.structural_equal(zbinded, zexpected) +def test_bind_duplicated_params(): + a = relay.var("a", shape=(1,)) + aa = relay.var("a", shape=(1,)) + s = a + aa + func = relay.Function([a, aa], s) + + with pytest.raises(TVMError): + relay.build_module.bind_params_by_name(func, {"a": [1.0]}) + + if __name__ == "__main__": test_bind_params() + test_bind_duplicated_params() diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 6d2ac21cc7ff..bcd9066b1ba7 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -208,6 +208,17 @@ def test_conv2d_attrs(): check_json_roundtrip(out) +def test_large_grpah(): + # Test large graphs to avoid stack overflow in serialize/deserialize + size = int(1e5) + var = [relay.var("var_" + str(i), shape=(2, 3)) for i in range(size)] + body = var[-1] + for i in range(size, 1, -1): + body = relay.Let(var[i - 1], op.add(var[i - 2], var[i - 2]), body) + func = relay.Function([var[0]], body) + check_json_roundtrip(func) + + if __name__ == "__main__": test_span() test_constant() @@ -222,3 +233,4 @@ def test_conv2d_attrs(): test_tuple_get_item() test_op() test_conv2d_attrs() + test_large_grpah() diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2834bba9248b..21c460fa0371 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -254,7 +254,7 @@ def test_null_attribute(): z = relay.Function([x], y) z = z.with_attr("TestAttribute", None) txt = astext(z) - assert "TestAttribute=(nullptr)" in txt + assert "TestAttribute=None" in txt def test_span(): diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index ca792204c835..c6eb7531f635 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -26,7 +26,7 @@ from tvm import relay, runtime from tvm.contrib import utils from tvm.relay import transform -from tvm.relay.backend import compile_engine +from tvm.relay.backend import te_compiler from tvm.relay.build_module import bind_params_by_name from tvm.relay.op.contrib.register import get_pattern_table @@ -47,7 +47,7 @@ def check_result( return # Run the reference result - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(ref_mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) @@ -61,7 +61,7 @@ def check_result( ref_result = out.numpy() def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -71,7 +71,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref_result, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) rt_mod = tvm.contrib.graph_executor.create(json, lib, device) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index eaddd33678df..754c9d1c4a74 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1422,7 +1422,8 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ verify_sparse_to_dense( [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1] ) # floats - verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + # default value not specified + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # negative test cases # sparse indices should be ints @@ -1757,7 +1758,7 @@ def verify_func(target, dev, func, data, ref_res): tvm.testing.assert_allclose(op_result.numpy(), ref_result, rtol=1e-5) else: tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) - relay.backend.compile_engine.get().clear() + relay.backend.te_compiler.get().clear() def test_adv_index(target, dev, executor_kind): @@ -1970,7 +1971,8 @@ def calc_numpy_unique(data, is_sorted=False): uniq = uniq[order].astype(data.dtype) inverse = np.array([reverse_order[i] for i in inverse]).astype("int32") counts = counts[order].astype("int32") - index = np.sort(index) # In unsorted case, need to sort the index of first occurence + # In unsorted case, need to sort the index of first occurence + index = np.sort(index) return [ uniq.astype(data.dtype), index.astype("int32"), diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 7b4eb5231a2c..3a5f458d5970 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -22,6 +22,16 @@ from tvm.relay.testing import run_infer_type, create_workload +def annot_func(f): + """Returns f with arg/result device attributes for the argument and result.""" + return relay.op.annotation.function_on_device(f, [tvm.cpu()], tvm.cpu()) + + +def annot_expr(e): + """Returns e wrapped with an on_device annotation.""" + return relay.op.annotation.on_device(e, tvm.cpu(), is_fixed=True) + + def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) @@ -75,7 +85,35 @@ def expected(): with tvm.target.Target("cuda"): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) + + +def test_fold_const_with_on_device(): + """Make sure on_device annotations don't get in the way of constant folding""" + c_data = np.array([1, 2, 3]).astype("float32") + t = relay.TensorType([1, 2, 3], "float32") + + def before(): + c = relay.const(c_data) + x = relay.var("x", t) + y = relay.add(c, c) + y = relay.multiply(y, relay.const(2, "float32")) + y = relay.add(x, y) + z = relay.add(y, c) + f = relay.Function([x], z) + return annot_func(f) + + def expected(): + x = relay.var("x", t) + c_folded = (c_data + c_data) * 2 + y = relay.add(x, relay.const(c_folded)) + z = relay.add(y, relay.const(c_data)) + f = relay.Function([x], z) + return annot_func(f) + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_let(): @@ -101,7 +139,37 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) + + +def test_fold_let_with_on_device(): + """Make sure on_device annotations don't get in the way of constant folding, + and inlined constants bring their annotations with them.""" + c_data = np.array(1).astype("float32") + t = relay.TensorType([1], "float32") + + def before(): + sb = relay.ScopeBuilder() + x = relay.var("x", t) + t1 = sb.let("t1", annot_expr(relay.const(c_data))) + t2 = sb.let("t2", annot_expr(relay.add(t1, t1))) + t3 = sb.let("t3", annot_expr(relay.add(t2, x))) + sb.ret(t3) + f = relay.Function([x], sb.get()) + return annot_func(f) + + def expected(): + sb = relay.ScopeBuilder() + x = relay.var("x", t) + c_folded = c_data + c_data + t3 = sb.let("t3", annot_expr(relay.add(annot_expr(relay.const(c_folded)), x))) + sb.ret(t3) + f = relay.Function([x], sb.get()) + return annot_func(f) + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_tuple(): @@ -124,7 +192,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_concat(): @@ -143,7 +211,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_if(): @@ -164,7 +232,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) cond_data = np.array(0).astype("bool") @@ -182,7 +250,7 @@ def expected(): zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_shape_of(): @@ -204,7 +272,7 @@ def expected(dtype): for dtype in ["int32", "float32"]: zz = run_opt_pass(before(dtype), transform.FoldConstant()) zexpected = run_opt_pass(expected(dtype), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_ndarray_size(): @@ -227,7 +295,7 @@ def expected(dtype): for dtype in ["int32", "float32"]: zz = run_opt_pass(before(dtype), transform.FoldConstant()) zexpected = run_opt_pass(expected(dtype), transform.InferType()) - assert tvm.ir.structural_equal(zz, zexpected) + tvm.ir.assert_structural_equal(zz, zexpected) def test_fold_batch_norm(): @@ -272,7 +340,7 @@ def initializer(_, param): mod = remove_bn_pass(mod) expect = run_infer_type(expected()) - assert tvm.ir.structural_equal(mod["main"], expect) + tvm.ir.assert_structural_equal(mod["main"], expect) def test_fold_dropout(): @@ -295,15 +363,11 @@ def before(): with tvm.transform.PassContext(opt_level=3): after_mod = passes(before_mod) - assert tvm.ir.structural_equal(run_infer_type(before_mod["main"]), after_mod["main"]) + tvm.ir.assert_structural_equal(run_infer_type(before_mod["main"]), after_mod["main"]) if __name__ == "__main__": - test_fold_const() - test_fold_let() - test_fold_tuple() - test_fold_concat() - test_fold_shape_of() - test_fold_batch_norm() - test_fold_ndarray_size() - test_fold_dropout() + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index 321c74f4bbd8..58baad2b0e8f 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -59,9 +59,11 @@ def test_pass_timing_instrument(): assert profiles == "" -def test_custom_instrument(): - @pass_instrument - class MyTest: +instrument_definition_type = tvm.testing.parameter("decorator", "subclass") + + +def test_custom_instrument(instrument_definition_type): + class BaseTest: def __init__(self): self.events = [] @@ -77,6 +79,16 @@ def run_before_pass(self, mod, info): def run_after_pass(self, mod, info): self.events.append("run after " + info.name) + if instrument_definition_type == "decorator": + MyTest = pass_instrument(BaseTest) + + elif instrument_definition_type == "subclass": + + class MyTest(BaseTest, tvm.ir.instrument.PassInstrument): + def __init__(self): + BaseTest.__init__(self) + tvm.ir.instrument.PassInstrument.__init__(self) + mod = get_test_model() my_test = MyTest() with tvm.transform.PassContext(instruments=[my_test]): diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 93cd6f791765..90d88169225c 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -22,6 +22,7 @@ import numpy as np import tvm +from tvm.relay.backend import te_compiler import tvm.relay.testing import tvm.relay.op as reg from tvm import relay @@ -29,7 +30,6 @@ from tvm.relay import transform from tvm.relay.testing import byoc from tvm.contrib import utils -from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.op.contrib.register import get_pattern_table @@ -143,7 +143,7 @@ def update_lib(lib): return lib def check_vm_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -157,7 +157,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.numpy(), ref, rtol=tol, atol=tol) def check_graph_executor_result(): - compile_engine.get().clear() + te_compiler.get().clear() with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) @@ -326,6 +326,49 @@ def expected(): check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res) +def test_extern_compiler_sanitized_ops(): + def expected(): + mod = tvm.IRModule() + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + add = x0 + y0 + # Function that uses C compiler + func = relay.Function([x0, y0], add) + func = set_func_attr(func, "unsanitary-name++", "tvmgen_default_unsanitary_name___main_0") + glb_0 = relay.GlobalVar("tvmgen_default_unsanitary_name___main_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [x, y]) + # Function that uses default compiler. Ops are fused in this function. + p0 = relay.var("p0", shape=(8, 8)) + log = relay.log(p0) + exp = relay.exp(p0) + concat = relay.concatenate([log, exp], axis=0) + fused_func = relay.Function([p0], concat) + fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + fused_call = relay.Call(fused_func, [add_call]) + main = relay.Function([x, y], fused_call) + mod["main"] = main + mod = transform.InferType()(mod) + return mod + + x = relay.var("x", shape=(8, 8)) + y = relay.var("y", shape=(8, 8)) + add = x + y + log = relay.log(add) + exp = relay.exp(add) + concat = relay.concatenate([log, exp], axis=0) + f = relay.Function([x, y], concat) + mod = tvm.IRModule() + mod["main"] = f + mod = WhiteListAnnotator(["add", "subtract", "multiply"], "unsanitary-name++")(mod) + mod = transform.PartitionGraph()(mod) + fused_mod = transform.FuseOps(2)(mod) + expected_mod = expected() + assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True) + + def test_extern_ccompiler_multiple_functions(): def expected(): mod = tvm.IRModule() @@ -508,7 +551,7 @@ def test_extern_dnnl_mobilenet(): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **params ) - compile_engine.get().clear() + te_compiler.get().clear() check_result(mod, {"data": i_data}, (1, 1000), ref_res.numpy(), tol=1e-5, params=params) @@ -950,7 +993,7 @@ def test_exec(mod, params, ref_mod, ref_params, out_shape): ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu(0)).evaluate()( i_data, **ref_params ) - compile_engine.get().clear() + te_compiler.get().clear() mod = get_partitoned_mod(mod, params, dnnl_patterns) diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 79ed014c5503..29e271b1c4d7 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -166,7 +166,6 @@ def test_threefry_generate_out_size(): if __name__ == "__main__": - test_threefry_repeatability(tvm.target.Target("llvm"), tvm.device("cpu")) - test_threefry_split(tvm.target.Target("llvm"), tvm.device("cpu")) - test_threefry_sequential_generate(tvm.target.Target("llvm"), tvm.device("cpu")) - test_threefry_sequential_generate_remaining(tvm.target.Target("llvm"), tvm.device("cpu")) + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_relay_te_compiler.py similarity index 93% rename from tests/python/relay/test_backend_compile_engine.py rename to tests/python/relay/test_relay_te_compiler.py index 092cae01f568..f8498ae83648 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_relay_te_compiler.py @@ -21,6 +21,7 @@ from tvm import relay from tvm import autotvm from tvm import topi +from tvm.relay.backend import te_compiler from tvm.relay.testing import run_infer_type from tvm.relay.testing.temp_op_attr import TempOpAttr import tvm.testing @@ -98,7 +99,7 @@ def _get_impls(dshape, wshape): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.get_valid_implementations( + return relay.backend.te_compiler.get_valid_implementations( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -121,7 +122,7 @@ def _select_impl(dshape, wshape, use_autotvm=False): weight = relay.var("wshape", shape=wshape) out = relay.nn.conv2d(data, weight, padding=(1, 1)) out = run_infer_type(out) - return relay.backend.compile_engine.select_implementation( + return relay.backend.te_compiler.select_implementation( relay.op.get("nn.conv2d"), out.attrs, [te.placeholder(dshape), te.placeholder(wshape)], @@ -161,8 +162,8 @@ def _select_impl(dshape, wshape, use_autotvm=False): assert impl.name == "conv2d_1" -def test_compile_engine(): - engine = relay.backend.compile_engine.get() +def test_te_compiler(): + tec = relay.backend.te_compiler.get() def get_func(shape): x = relay.var("x", shape=shape) @@ -173,31 +174,30 @@ def get_func(shape): mod = relay.transform.InferType()(mod) return mod["main"] - z1 = engine.lower(get_func((10,)), "llvm") - z2 = engine.lower(get_func((10,)), "llvm") - z3 = engine.lower(get_func(()), "llvm") + z1 = tec.lower(get_func((10,)), "llvm") + z2 = tec.lower(get_func((10,)), "llvm") + z3 = tec.lower(get_func(()), "llvm") assert z1.same_as(z2) assert not z3.same_as(z1) if tvm.testing.device_enabled("cuda"): - z4 = engine.lower(get_func(()), "cuda") + z4 = tec.lower(get_func(()), "cuda") assert not z3.same_as(z4) # Test JIT target for target in ["llvm"]: dev = tvm.device(target) if tvm.testing.device_enabled(target): - f = engine.jit(get_func((10,)), target) + f = tec.jit(get_func((10,)), target) x = tvm.nd.array(np.ones(10).astype("float32"), device=dev) y = tvm.nd.empty((10,), device=dev) f(x, y) tvm.testing.assert_allclose(y.numpy(), x.numpy() * 3) - engine.dump() -# Note: Once compile engine is removed, we should keep this test so that +# Note: Once the te compiler is removed, we should keep this test so that # we make sure that opt_level=0 passes are being called correctly. def test_compile_placeholder_bypass(): - engine = relay.backend.compile_engine.get() + te_compiler = relay.backend.te_compiler.get() x = relay.var("x", shape=(2, 3)) y = relay.var("y", shape=(2, 3)) z = relay.var("z", shape=(2, 3)) @@ -264,7 +264,7 @@ def test_compile_nhwc_pack(): if __name__ == "__main__": test_get_valid_implementations() test_select_implementation() - test_compile_engine() + test_te_compiler() test_compile_placeholder_bypass() test_compile_injective_with_tuple() test_compile_tuple_dup() diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py new file mode 100644 index 000000000000..d78b822411bc --- /dev/null +++ b/tests/python/relay/test_runtime.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tvm import TVMError +from tvm.relay.backend import Runtime + + +def test_create(): + runtime = Runtime("cpp") + assert str(runtime) == "cpp" + + +def test_create_runtime_with_options(): + runtime = Runtime("c", {"system-lib": True}) + assert str(runtime) == "c" + assert runtime["system-lib"] + + +def test_attr_check(): + runtime = Runtime("c", {"system-lib": True}) + assert "woof" not in runtime + assert "system-lib" in runtime + + +def test_create_runtime_not_found(): + with pytest.raises(TVMError, match='Runtime "woof" is not defined'): + Runtime("woof", {}) + + +def test_create_runtime_attr_not_found(): + with pytest.raises(TVMError, match='Attribute "woof" is not available on this Runtime'): + Runtime("c", {"woof": "bark"}) + + +def test_create_runtime_attr_type_incorrect(): + with pytest.raises( + TVMError, + match='Attribute "system-lib" should have type "IntImm"' + ' but instead found "runtime.String"', + ): + Runtime("c", {"system-lib": "woof"}) + + +def test_list_runtimes(): + assert "c" in Runtime.list_runtimes() + + +@pytest.mark.parametrize("runtime", [Runtime("c"), "c"]) +def test_list_runtime_options(runtime): + aot_options = Runtime.list_runtime_options(runtime) + assert "system-lib" in aot_options + assert aot_options["system-lib"] == "IntImm" + + +def test_list_runtime_options_not_found(): + with pytest.raises(TVMError, match='Runtime "woof" is not defined'): + Runtime.list_runtime_options("woof") diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 8ec41523f9dc..79979747dfd8 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -32,6 +32,8 @@ import tvm.testing from tvm.relay.transform import InferType from tvm.relay.testing import mlp +from tvm.relay.dataflow_pattern import wildcard, is_op +from tvm.relay.backend.vm import VMCompiler def check_result(target, dev, args, expected_result, mod=None): @@ -973,6 +975,91 @@ def test_benchmark_end_to_end_rpc(): assert result.mean > 0 +def test_shape_func_nested_function(): + data_shape = (relay.Any(), 16) + weight_shape = (relay.Any(), 16) + + dense = relay.nn.dense( + relay.var("data", shape=data_shape), relay.var("weight", shape=weight_shape) + ) + mod = tvm.IRModule.from_expr(dense) + + patterns = [("test.dense", is_op("nn.dense")(wildcard(), wildcard()))] + passes = tvm.transform.Sequential( + [ + relay.transform.MergeComposite(patterns), + relay.transform.AnnotateTarget(["test"]), + relay.transform.PartitionGraph(), + ] + ) + + mod = passes(mod) + + compiler = VMCompiler() + compiler.lower(mod, "llvm") + + +@tvm.testing.requires_cuda +def test_storage_size_and_offset_on_cpu(): + """Tests allocations place sizes and offsets on the CPU host even if the rest + of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + + # CPU = device type 1 + # GPU = device type 2 + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], + param_device_types=[2], result_device_type=2) { + add(%a, %a) + } + """ + ) + + exe = relay.vm.compile( + input(), + tvm.target.Target("cuda"), + ) + + # This program needs two constants: + # - The size of the tensor's storage (first arg) to alloc_storage + # - The offset of the tensor within the storage (second arg) to alloc_tensor + # Both should be on the CPU + assert not "on device of type 2" in exe.constants + assert "on device of type 1" in exe.constants + + +@tvm.testing.requires_cuda +def test_reshape_shape_on_cpu(): + """Tests the argument to a reshape places the shape on the CPU host even if the rest + of the computation is on a different device type.""" + # TODO(mbs): Better would be to test ManifestAlloc independently. + + # CPU = device type 1 + # GPU = device type 2 + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32], + param_device_types=[2], result_device_type=2) { + reshape(%x, newshape=[2, 4, 2]) + } + """ + ) + + exe = relay.vm.compile( + input(), + tvm.target.Target("cuda"), + ) + + # The newshape annotation should have been turned into a constant on the CPU. + assert not "on device of type 2" in exe.constants + assert "on device of type 1" in exe.constants + + if __name__ == "__main__": import sys diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_se_scope.py new file mode 100644 index 000000000000..0a9384fa9c04 --- /dev/null +++ b/tests/python/target/test_se_scope.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm + + +def test_make_se_scope_for_device(): + se_scope = tvm.target.make_se_scope(tvm.device("cuda")) + assert se_scope.device_type == 2 + # ie kDLCUDA + assert se_scope.virtual_device_id == 0 + assert se_scope.target is None + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_and_target(): + target = tvm.target.Target("cuda") + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == "" + + +def test_make_se_scope_for_device_target_and_memory_scope(): + target = tvm.target.Target("cuda") + scope = "local" + se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) + assert se_scope.device_type == 2 # ie kDLCUDA + assert se_scope.target == target + assert se_scope.memory_scope == scope + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 9450a937a155..fbf908170938 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -35,9 +35,6 @@ from tvm.topi.utils import get_const_tuple from tvm.topi.testing import conv2d_nchw_python -pytest.importorskip("tvm.micro.testing") -from tvm.micro.testing import check_tune_log - BUILD = True DEBUG = False @@ -222,6 +219,7 @@ def test_platform_timer(): def test_autotune(): """Verify that autotune works with micro.""" import tvm.relay as relay + from tvm.micro.testing import check_tune_log data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32")) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py new file mode 100644 index 000000000000..f508c7d252e1 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +from typing import List + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.ir.module import IRModule +from tvm.meta_schedule.integration import ( + ExtractedTask, + MetaScheduleContext, + TaskExtraction, +) +from tvm.meta_schedule.testing import get_network +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +@tvm.script.ir_module +class MockModule: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + for i in T.serial(0, 16): + with T.block("matmul"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): + (task,) = tasks + assert isinstance(task, ExtractedTask) + assert task.task_name == "mock-task" + tvm.ir.assert_structural_equal(task.mod, mod) + (tir_mod,) = task.dispatched + tvm.ir.assert_structural_equal(tir_mod, MockModule) + + +def test_meta_schedule_integration_task_extraction_query(): + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + env = TaskExtraction() + env.query(task_name="mock-task", mod=mod, dispatched=[MockModule]) + _check_mock_task(env.tasks, mod) + + +def test_meta_schedule_integration_current(): + env = TaskExtraction() + with env: + assert MetaScheduleContext.current() == env + + +def test_meta_schedule_integration_no_current(): + assert MetaScheduleContext.current() is None + + +def test_meta_schedule_integration_multiple_current(): + env = TaskExtraction() + with env: + with pytest.raises(ValueError): + with env: + ... + + +def test_meta_schedule_integration_query_inside_with_scope(): + mod, _, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + env = TaskExtraction() + with env: + MetaScheduleContext.query_inside_with_scope( + task_name="mock-task", + mod=mod, + dispatched=[MockModule], + ) + _check_mock_task(env.tasks, mod) + + +def test_meta_schedule_integration_extract_from_resnet(): + mod, params, _, _ = get_network( + name="resnet-18", + batch_size=1, + layout="NHWC", + dtype="float32", + ) + extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params) + assert len(extracted_tasks) == 30 + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 3f7749ca9e2c..49a3f6309183 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -24,9 +24,8 @@ import tvm from tvm.script import tir as T - from tvm.tir.schedule import Schedule -from tvm.meta_schedule.space_generator import ScheduleFn, SpaceGeneratorUnion +from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -86,5 +85,10 @@ def test_meta_schedule_design_space_generator_union(): _check_correct(design_space) +def test_meta_schedule_design_space_generator_NIE(): + with pytest.raises(NotImplementedError): + PySpaceGenerator() + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 4854aeb5f5aa..edff3552d717 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -33,8 +33,7 @@ from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.task_scheduler import RoundRobin -from tvm.meta_schedule.utils import structural_hash +from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -224,5 +223,77 @@ def test_meta_schedule_task_scheduler_multiple(): assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total +def test_meta_schedule_task_scheduler_NIE(): + class MyTaskScheduler(PyTaskScheduler): + pass + + with pytest.raises(NotImplementedError): + MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) + + +def test_meta_schedule_task_scheduler_override_next_task_id_only(): + class MyTaskScheduler(PyTaskScheduler): + done = set() + + def next_task_id(self) -> int: + while len(self.done) != len(tasks): + x = random.randint(0, len(tasks) - 1) + task = tasks[x] + if not task.is_stopped: + """Calling base func via following route: + Python side: + PyTaskScheduler does not have `_is_task_running` + Call TaskScheduler's `is_task_running`, which calls ffi + C++ side: + The ffi calls TaskScheduler's `is_task_running` + But it is overridden in PyTaskScheduler + PyTaskScheduler checks if the function is overridden in python + If not, it returns the TaskScheduler's vtable, calling + TaskScheduler::IsTaskRunning + """ + if self._is_task_running(x): + # Same Here + self._join_running_task(x) + return x + else: + self.done.add(x) + return -1 + + num_trials_per_iter = 6 + num_trials_total = 101 + tasks = [ + TuneContext( + MatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="Matmul", + rand_state=42, + ), + TuneContext( + MatmulReluModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="MatmulRelu", + rand_state=0xDEADBEEF, + ), + TuneContext( + BatchMatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="BatchMatmul", + rand_state=0x114514, + ), + ] + database = DummyDatabase() + scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database) + scheduler.tune() + assert len(database) == num_trials_total * len(tasks) + for task in tasks: + assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_micro_project_api.py b/tests/python/unittest/test_micro_project_api.py index e319318656ef..1e511c41d73e 100644 --- a/tests/python/unittest/test_micro_project_api.py +++ b/tests/python/unittest/test_micro_project_api.py @@ -26,45 +26,52 @@ import tvm -pytest.importorskip("tvm.micro") -from tvm.micro import project_api +# Implementing as a fixture so that the tvm.micro import doesn't occur +# until fixture setup time. This is necessary for pytest's collection +# phase to work when USE_MICRO=OFF, while still explicitly listing the +# tests as skipped. +@tvm.testing.fixture +def BaseTestHandler(): + from tvm.micro import project_api + + class BaseTestHandler_Impl(project_api.server.ProjectAPIHandler): + + DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo( + platform_name="platform_name", + is_template=True, + model_library_format_path="./model-library-format-path.sh", + project_options=[ + project_api.server.ProjectOption(name="foo", help="Option foo"), + project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), + ], + ) -class BaseTestHandler(project_api.server.ProjectAPIHandler): - - DEFAULT_TEST_SERVER_INFO = project_api.server.ServerInfo( - platform_name="platform_name", - is_template=True, - model_library_format_path="./model-library-format-path.sh", - project_options=[ - project_api.server.ProjectOption(name="foo", help="Option foo"), - project_api.server.ProjectOption(name="bar", choices=["qux"], help="Option bar"), - ], - ) + def server_info_query(self, tvm_version): + return self.DEFAULT_TEST_SERVER_INFO - def server_info_query(self, tvm_version): - return self.DEFAULT_TEST_SERVER_INFO + def generate_project(self, model_library_format_path, crt_path, project_path, options): + assert False, "generate_project is not implemented for this test" - def generate_project(self, model_library_format_path, crt_path, project_path, options): - assert False, "generate_project is not implemented for this test" + def build(self, options): + assert False, "build is not implemented for this test" - def build(self, options): - assert False, "build is not implemented for this test" + def flash(self, options): + assert False, "flash is not implemented for this test" - def flash(self, options): - assert False, "flash is not implemented for this test" + def open_transport(self, options): + assert False, "open_transport is not implemented for this test" - def open_transport(self, options): - assert False, "open_transport is not implemented for this test" + def close_transport(self, options): + assert False, "open_transport is not implemented for this test" - def close_transport(self, options): - assert False, "open_transport is not implemented for this test" + def read_transport(self, n, timeout_sec): + assert False, "read_transport is not implemented for this test" - def read_transport(self, n, timeout_sec): - assert False, "read_transport is not implemented for this test" + def write_transport(self, data, timeout_sec): + assert False, "write_transport is not implemented for this test" - def write_transport(self, data, timeout_sec): - assert False, "write_transport is not implemented for this test" + return BaseTestHandler_Impl class Transport: @@ -100,6 +107,8 @@ def write(self, data): class ClientServerFixture: def __init__(self, handler): + from tvm.micro import project_api + self.handler = handler self.client_to_server = Transport() self.server_to_client = Transport() @@ -121,7 +130,8 @@ def _process_server_request(self): ), "Server failed to process request" -def test_server_info_query(): +@tvm.testing.requires_micro +def test_server_info_query(BaseTestHandler): fixture = ClientServerFixture(BaseTestHandler()) # Examine reply explicitly because these are the defaults for all derivative test cases. @@ -136,7 +146,10 @@ def test_server_info_query(): ] -def test_server_info_query_wrong_tvm_version(): +@tvm.testing.requires_micro +def test_server_info_query_wrong_tvm_version(BaseTestHandler): + from tvm.micro import project_api + def server_info_query(tvm_version): raise project_api.server.UnsupportedTVMVersionError() @@ -148,7 +161,10 @@ def server_info_query(tvm_version): assert "UnsupportedTVMVersionError" in str(exc_info.value) -def test_server_info_query_wrong_protocol_version(): +@tvm.testing.requires_micro +def test_server_info_query_wrong_protocol_version(BaseTestHandler): + from tvm.micro import project_api + ServerInfoProtocol = collections.namedtuple( "ServerInfoProtocol", list(project_api.server.ServerInfo._fields) + ["protocol_version"] ) @@ -166,7 +182,8 @@ def server_info_query(tvm_version): assert "microTVM API Server supports protocol version 0; want 1" in str(exc_info.value) -def test_base_test_handler(): +@tvm.testing.requires_micro +def test_base_test_handler(BaseTestHandler): """All methods should raise AssertionError on BaseTestHandler.""" fixture = ClientServerFixture(BaseTestHandler()) @@ -180,7 +197,8 @@ def test_base_test_handler(): assert (exc_info.exception) == f"{method} is not implemented for this test" -def test_build(): +@tvm.testing.requires_micro +def test_build(BaseTestHandler): with mock.patch.object(BaseTestHandler, "build", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.build(options={"bar": "baz"}) @@ -188,14 +206,18 @@ def test_build(): fixture.handler.build.assert_called_once_with(options={"bar": "baz"}) -def test_flash(): +@tvm.testing.requires_micro +def test_flash(BaseTestHandler): with mock.patch.object(BaseTestHandler, "flash", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.flash(options={"bar": "baz"}) fixture.handler.flash.assert_called_once_with(options={"bar": "baz"}) -def test_open_transport(): +@tvm.testing.requires_micro +def test_open_transport(BaseTestHandler): + from tvm.micro import project_api + timeouts = project_api.server.TransportTimeouts( session_start_retry_timeout_sec=1.0, session_start_timeout_sec=2.0, @@ -210,14 +232,18 @@ def test_open_transport(): fixture.handler.open_transport.assert_called_once_with({"bar": "baz"}) -def test_close_transport(): +@tvm.testing.requires_micro +def test_close_transport(BaseTestHandler): with mock.patch.object(BaseTestHandler, "close_transport", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) fixture.client.close_transport() fixture.handler.close_transport.assert_called_once_with() -def test_read_transport(): +@tvm.testing.requires_micro +def test_read_transport(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object(BaseTestHandler, "read_transport", return_value=b"foo\x1b") as patch: fixture = ClientServerFixture(BaseTestHandler()) assert fixture.client.read_transport(128, timeout_sec=5.0) == {"data": b"foo\x1b"} @@ -239,7 +265,10 @@ def test_read_transport(): assert fixture.handler.read_transport.call_count == 3 -def test_write_transport(): +@tvm.testing.requires_micro +def test_write_transport(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object(BaseTestHandler, "write_transport", return_value=None) as patch: fixture = ClientServerFixture(BaseTestHandler()) assert fixture.client.write_transport(b"foo", timeout_sec=5.0) is None @@ -264,7 +293,10 @@ class ProjectAPITestError(Exception): """An error raised in test.""" -def test_method_raises_error(): +@tvm.testing.requires_micro +def test_method_raises_error(BaseTestHandler): + from tvm.micro import project_api + with mock.patch.object( BaseTestHandler, "close_transport", side_effect=ProjectAPITestError ) as patch: @@ -276,7 +308,10 @@ def test_method_raises_error(): assert "ProjectAPITestError" in str(exc_info.value) -def test_method_not_found(): +@tvm.testing.requires_micro +def test_method_not_found(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) with pytest.raises(project_api.server.JSONRPCError) as exc_info: @@ -285,7 +320,10 @@ def test_method_not_found(): assert exc_info.value.code == project_api.server.ErrorCode.METHOD_NOT_FOUND -def test_extra_param(): +@tvm.testing.requires_micro +def test_extra_param(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # test one with has_preprocssing and one without @@ -304,7 +342,10 @@ def test_extra_param(): assert "open_transport: extra parameters: invalid_param_name" in str(exc_info.value) -def test_missing_param(): +@tvm.testing.requires_micro +def test_missing_param(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # test one with has_preprocssing and one without @@ -323,7 +364,10 @@ def test_missing_param(): assert "open_transport: parameter options not given" in str(exc_info.value) -def test_incorrect_param_type(): +@tvm.testing.requires_micro +def test_incorrect_param_type(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # The error message given at the JSON-RPC server level doesn't make sense when preprocessing is @@ -338,7 +382,10 @@ def test_incorrect_param_type(): ) -def test_invalid_request(): +@tvm.testing.requires_micro +def test_invalid_request(BaseTestHandler): + from tvm.micro import project_api + fixture = ClientServerFixture(BaseTestHandler()) # Invalid JSON does not get a reply. diff --git a/tests/python/unittest/test_micro_transport.py b/tests/python/unittest/test_micro_transport.py index a188e612763f..2fbfada198e3 100644 --- a/tests/python/unittest/test_micro_transport.py +++ b/tests/python/unittest/test_micro_transport.py @@ -26,11 +26,15 @@ import tvm.testing -@tvm.testing.requires_micro -class TransportLoggerTests(unittest.TestCase): +# Implementing as a fixture so that the tvm.micro import doesn't occur +# until fixture setup time. This is necessary for pytest's collection +# phase to work when USE_MICRO=OFF, while still explicitly listing the +# tests as skipped. +@tvm.testing.fixture +def transport(): import tvm.micro - class TestTransport(tvm.micro.transport.Transport): + class MockTransport_Impl(tvm.micro.transport.Transport): def __init__(self): self.exc = None self.to_return = None @@ -62,125 +66,159 @@ def read(self, n, timeout_sec): def write(self, data, timeout_sec): return self._raise_or_return() - def test_transport_logger(self): - """Tests the TransportLogger class.""" - - logger = logging.getLogger("transport_logger_test") - with self.assertLogs(logger) as test_log: - transport = self.TestTransport() - transport_logger = tvm.micro.transport.TransportLogger("foo", transport, logger=logger) - - transport_logger.open() - assert test_log.records[-1].getMessage() == "foo: opening transport" - - ########### read() tests ########## - - # Normal log, single-line data returned. - transport.to_return = b"data" - transport_logger.read(23, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61" - " data" - ) - - # Normal log, multi-line data returned. - transport.to_return = b"data" * 6 - transport_logger.read(23, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: read { 3.00s} 23 B -> [ 24 B]:\n" - "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 64 61 74 61 datadata" - ) - - # Lack of timeout prints. - transport.to_return = b"data" - transport_logger.read(15, None) - assert test_log.records[-1].getMessage() == ( - "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61" - " data" - ) - - # IoTimeoutError includes the timeout value. - transport.exc = tvm.micro.transport.IoTimeoutError() - with self.assertRaises(tvm.micro.transport.IoTimeoutError): - transport_logger.read(23, 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]" - ) - - # Other exceptions are logged by name. - transport.exc = tvm.micro.transport.TransportClosedError() - with self.assertRaises(tvm.micro.transport.TransportClosedError): - transport_logger.read(8, 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: read { 0.00s} 8 B -> [err: TransportClosedError]" - ) - - # KeyboardInterrupt produces no log record. - before_len = len(test_log.records) - transport.exc = KeyboardInterrupt() - with self.assertRaises(KeyboardInterrupt): - transport_logger.read(8, 0.0) - - assert len(test_log.records) == before_len - - ########### write() tests ########## - - # Normal log, single-line data written. - transport.to_return = 3 - transport_logger.write(b"data", 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" - " data" - ) - - # Normal log, multi-line data written. - transport.to_return = 20 - transport_logger.write(b"data" * 6, 3.0) - assert test_log.records[-1].getMessage() == ( - "foo: write { 3.00s} <- [ 24 B]:\n" - "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" - "0010 64 61 74 61 64 61 74 61 datadata" - ) - - # Lack of timeout prints. - transport.to_return = 3 - transport_logger.write(b"data", None) - assert test_log.records[-1].getMessage() == ( - "foo: write { None } <- [ 4 B]: 64 61 74 61" - " data" - ) - - # IoTimeoutError includes the timeout value. - transport.exc = tvm.micro.transport.IoTimeoutError() - with self.assertRaises(tvm.micro.transport.IoTimeoutError): - transport_logger.write(b"data", 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]" - ) - - # Other exceptions are logged by name. - transport.exc = tvm.micro.transport.TransportClosedError() - with self.assertRaises(tvm.micro.transport.TransportClosedError): - transport_logger.write(b"data", 0.0) - - assert test_log.records[-1].getMessage() == ( - "foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]" - ) - - # KeyboardInterrupt produces no log record. - before_len = len(test_log.records) - transport.exc = KeyboardInterrupt() - with self.assertRaises(KeyboardInterrupt): - transport_logger.write(b"data", 0.0) - - assert len(test_log.records) == before_len - - transport_logger.close() - assert test_log.records[-1].getMessage() == "foo: closing transport" + return MockTransport_Impl() + + +@tvm.testing.fixture +def transport_logger(transport): + logger = logging.getLogger("transport_logger_test") + return tvm.micro.transport.TransportLogger("foo", transport, logger=logger) + + +@tvm.testing.fixture +def get_latest_log(caplog): + def inner(): + return caplog.records[-1].getMessage() + + with caplog.at_level(logging.INFO, "transport_logger_test"): + yield inner + + +@tvm.testing.requires_micro +def test_open(transport_logger, get_latest_log): + transport_logger.open() + assert get_latest_log() == "foo: opening transport" + + +@tvm.testing.requires_micro +def test_close(transport_logger, get_latest_log): + transport_logger.close() + assert get_latest_log() == "foo: closing transport" + + +@tvm.testing.requires_micro +def test_read_normal(transport, transport_logger, get_latest_log): + transport.to_return = b"data" + transport_logger.read(23, 3.0) + assert get_latest_log() == ( + "foo: read { 3.00s} 23 B -> [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_read_multiline(transport, transport_logger, get_latest_log): + transport.to_return = b"data" * 6 + transport_logger.read(23, 3.0) + assert get_latest_log() == ( + "foo: read { 3.00s} 23 B -> [ 24 B]:\n" + "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" + "0010 64 61 74 61 64 61 74 61 datadata" + ) + + +@tvm.testing.requires_micro +def test_read_no_timeout_prints(transport, transport_logger, get_latest_log): + transport.to_return = b"data" + transport_logger.read(15, None) + assert get_latest_log() == ( + "foo: read { None } 15 B -> [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_read_io_timeout(transport, transport_logger, get_latest_log): + # IoTimeoutError includes the timeout value. + transport.exc = tvm.micro.transport.IoTimeoutError() + with pytest.raises(tvm.micro.transport.IoTimeoutError): + transport_logger.read(23, 0.0) + + assert get_latest_log() == ("foo: read { 0.00s} 23 B -> [IoTimeoutError 0.00s]") + + +@tvm.testing.requires_micro +def test_read_other_exception(transport, transport_logger, get_latest_log): + # Other exceptions are logged by name. + transport.exc = tvm.micro.transport.TransportClosedError() + with pytest.raises(tvm.micro.transport.TransportClosedError): + transport_logger.read(8, 0.0) + + assert get_latest_log() == ("foo: read { 0.00s} 8 B -> [err: TransportClosedError]") + + +@tvm.testing.requires_micro +def test_read_keyboard_interrupt(transport, transport_logger, get_latest_log): + # KeyboardInterrupt produces no log record. + transport.exc = KeyboardInterrupt() + with pytest.raises(KeyboardInterrupt): + transport_logger.read(8, 0.0) + + with pytest.raises(IndexError): + get_latest_log() + + +@tvm.testing.requires_micro +def test_write_normal(transport, transport_logger, get_latest_log): + transport.to_return = 3 + transport_logger.write(b"data", 3.0) + assert get_latest_log() == ( + "foo: write { 3.00s} <- [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_write_multiline(transport, transport_logger, get_latest_log): + # Normal log, multi-line data written. + transport.to_return = 20 + transport_logger.write(b"data" * 6, 3.0) + assert get_latest_log() == ( + "foo: write { 3.00s} <- [ 24 B]:\n" + "0000 64 61 74 61 64 61 74 61 64 61 74 61 64 61 74 61 datadatadatadata\n" + "0010 64 61 74 61 64 61 74 61 datadata" + ) + + +@tvm.testing.requires_micro +def test_write_no_timeout_prints(transport, transport_logger, get_latest_log): + transport.to_return = 3 + transport_logger.write(b"data", None) + assert get_latest_log() == ( + "foo: write { None } <- [ 4 B]: 64 61 74 61" + " data" + ) + + +@tvm.testing.requires_micro +def test_write_io_timeout(transport, transport_logger, get_latest_log): + # IoTimeoutError includes the timeout value. + transport.exc = tvm.micro.transport.IoTimeoutError() + with pytest.raises(tvm.micro.transport.IoTimeoutError): + transport_logger.write(b"data", 0.0) + + assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [IoTimeoutError 0.00s]") + + +@tvm.testing.requires_micro +def test_write_other_exception(transport, transport_logger, get_latest_log): + # Other exceptions are logged by name. + transport.exc = tvm.micro.transport.TransportClosedError() + with pytest.raises(tvm.micro.transport.TransportClosedError): + transport_logger.write(b"data", 0.0) + + assert get_latest_log() == ("foo: write { 0.00s} <- [ 4 B]: [err: TransportClosedError]") + + +@tvm.testing.requires_micro +def test_write_keyboard_interrupt(transport, transport_logger, get_latest_log): + # KeyboardInterrupt produces no log record. + transport.exc = KeyboardInterrupt() + with pytest.raises(KeyboardInterrupt): + transport_logger.write(b"data", 0.0) + + with pytest.raises(IndexError): + get_latest_log() if __name__ == "__main__": diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index 3e38a526855a..b67142b42358 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -184,8 +184,15 @@ def test_report_serialization(): report = vm.profile(data, func_name="main") report2 = Report.from_json(report.json()) - # equality on reports compares pointers, so we compare the printed results instead. - assert str(report) == str(report2) + # Equality on reports compares pointers, so we compare the printed + # results instead. + + # Use .table() instead of str(), because str() includes aggregate + # and column summations whose values may be impacted by otherwise + # negligible conversion errors. (2 occurrences / 3000 trials) + assert report.table(aggregate=False, col_sums=False) == report2.table( + aggregate=False, col_sums=False + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 56392ec8cccc..2ac2ec9dd9e9 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -142,4 +142,5 @@ def check_erf(dev, n, dtype): if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan() + test_opencl_max() test_opencl_erf() diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 1edc5d311759..7b708cbe0c12 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -17,7 +17,6 @@ import random import re -import sys import threading import numpy as np @@ -557,4 +556,6 @@ def do_compute(ins, outs): if __name__ == "__main__": - sys.exit(pytest.main(sys.argv)) + import sys + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index 4ea35c0a2d6c..e508fbb0f747 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -105,6 +105,31 @@ def opaque_access_func() -> None: ) +@T.prim_func +def access_in_if_then_else_func() -> None: + A = T.alloc_buffer([8]) + B = T.alloc_buffer([8]) + with T.block(): + T.reads([A[0:5]]) + T.writes([B[0:8]]) + for i in T.serial(0, 8): + B[i] = T.if_then_else(i < 5, A[i], 0.0, dtype="float32") + + +@T.prim_func +def access_in_branch_func() -> None: + A = T.alloc_buffer([8]) + B = T.alloc_buffer([8]) + with T.block(): + T.reads([A[0:7]]) + T.writes([B[0:8]]) + for i in T.serial(0, 8): + if i < 5: + B[i] = A[i] + 1.0 + else: + B[i] = A[i - 1] + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -175,8 +200,30 @@ def test_match_buffer(): tvm.ir.assert_structural_equal(block_inner.writes, ret[1]) +def test_access_in_if_then_else_func(): + block = access_in_if_then_else_func.body.block.body.block + alloc_buffers = access_in_if_then_else_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + +def test_access_in_branch_func(): + block = access_in_branch_func.body.block.body.block + alloc_buffers = access_in_branch_func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map) + ret1 = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(ret0[0], ret1[0]) + tvm.ir.assert_structural_equal(ret0[1], ret1[1]) + + if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() test_opaque_access() test_match_buffer() + test_access_in_if_then_else_func() + test_access_in_branch_func() diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 9e9563a66a5d..b7d78aad140d 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -343,6 +343,34 @@ def test_vectorize(): assert not valid[0] +@tvm.testing.requires_gpu +def test_vectorize_half(): + N = 1024 + + A = te.placeholder((N, N), name="A", dtype="float16") + B = te.compute((N, N), lambda i, j: A[i, j]) + + s = te.create_schedule([B.op]) + + i, j = s[B].op.axis + + s[B].bind(i, te.thread_axis("blockIdx.x")) + jo, ji = s[B].split(j, factor=8) + s[B].bind(jo, te.thread_axis("threadIdx.x")) + s[B].vectorize(ji) + + for target in ["opencl", "cuda"]: + if not tvm.testing.device_enabled(target): + continue + + valid = [None] + with tvm.transform.PassContext( + config={"tir.add_lower_pass": [(2, get_verify_pass(valid, max_vector_bytes=16))]} + ): + tvm.lower(s, [A, B]) + assert valid[0] + + @tvm.testing.requires_gpu def test_vthread(): N = 1024 @@ -409,5 +437,6 @@ def test_redundant_kernels(): test_multiple_kernels() test_wrong_bind() test_vectorize() + test_vectorize_half() test_vthread() test_redundant_kernels() diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 9075e93b9d45..93876c668913 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -239,6 +239,44 @@ def opaque_block(a: T.handle) -> None: A[i + 1] = A[i + 1] + A[i] +@T.prim_func +def block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.serial(0, 128): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + +@T.prim_func +def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block("outer"): + vi = T.axis.S(128, i) + with T.init(): + for j in T.serial(0, 128): + with T.block("init"): + vj = T.axis.S(128, j) + B[vi, vj] = 0.0 + for k in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block("inner"): + vj, vk = T.axis.remap("SR", [j, k]) + B[vi, vj] = B[vi, vj] + A[vi, vj, vk] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -361,5 +399,13 @@ def test_bind_after_bind(): verify_trace_roundtrip(s, mod=element_wise) +def test_block_inside_init(): + s = tir.Schedule(block_inside_init, debug_mask="all") + (i,) = s.get_loops(s.get_block("outer")) + s.bind(i, "threadIdx.x") + tvm.ir.assert_structural_equal(s.mod["main"], thread_bound_block_inside_init) + verify_trace_roundtrip(s, mod=block_inside_init) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index fbf0a6a5bd78..5d2676e41d1c 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -14,15 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys from collections import defaultdict +import sys import pytest -import tvm + from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip -from tvm.tir.schedule import Trace # pylint: disable=no-member,invalid-name,unused-variable @@ -30,9 +29,9 @@ @T.prim_func def elementwise(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128, 128)) - B = T.match_buffer(b, (128, 128, 128)) - for i, j, k in T.grid(128, 128, 128): + A = T.match_buffer(a, (128, 257, 1470)) + B = T.match_buffer(b, (128, 257, 1470)) + for i, j, k in T.grid(128, 257, 1470): with T.block("B"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 @@ -42,7 +41,7 @@ def elementwise(a: T.handle, b: T.handle) -> None: def test_sample_categorical(): - """Test sample categprical sampling function""" + """Test sample categorical sampling function""" n = 1000 sch = tir.Schedule(elementwise, seed=42, debug_mask="all") counter = defaultdict(int) @@ -87,5 +86,35 @@ def test_sample_categorical_serialize(): assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] +def test_sample_perfect_tile_power_of_two(): + sch = tir.Schedule(elementwise, debug_mask="all") + i, _, _ = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 128 + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_perfect_tile_prime(): + sch = tir.Schedule(elementwise, debug_mask="all") + _, i, _ = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 257 + verify_trace_roundtrip(sch, mod=elementwise) + + +def test_sample_perfect_tile_composite(): + sch = tir.Schedule(elementwise, debug_mask="all") + _, _, i = sch.get_loops(sch.get_block("B")) + factors = sch.sample_perfect_tile(i, n=4) + factors = [sch.get(i) for i in factors] + prod = factors[0] * factors[1] * factors[2] * factors[3] + assert prod == 1470 + verify_trace_roundtrip(sch, mod=elementwise) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 440d0ab67a50..d75bc1461c5e 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -36,13 +36,31 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j in T.grid(128, 128): with T.block("init"): vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = T.float32(0) + C[vi, vj] = 0.0 for k in range(0, 128): with T.block("update"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +@T.prim_func +def matmul_relu(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024)) + B = T.match_buffer(b, (1024, 1024)) + C = T.alloc_buffer((1024, 1024)) + D = T.match_buffer(d, (1024, 1024)) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j in T.grid(1024, 1024): + with T.block("relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(C[vi, vj], 0.0) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -142,5 +160,44 @@ def test_tir_schedule_remove_rv(): sch.get(block_rv) +def test_get_child_blocks(): + s = tir.Schedule(matmul, debug_mask="all") + init = s.get_block("init") + update = s.get_block("update") + # loop + blocks = s.get_child_blocks(s.get_loops(init)[0]) + assert len(blocks) == 2 + assert s.get(init) == s.get(blocks[0]) + assert s.get(update) == s.get(blocks[1]) + # block + root = s.get_block("root") + blocks = s.get_child_blocks(root) + assert len(blocks) == 2 + assert s.get(init) == s.get(blocks[0]) + assert s.get(update) == s.get(blocks[1]) + + +def test_get_producers(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + block = sch.get_block("relu") + (producer,) = sch.get_producers(block) + assert tvm.ir.structural_equal( + sch.get_sref(producer).stmt, + sch.get_sref(sch.get_block("matmul")).stmt, + ) + verify_trace_roundtrip(sch, mod=matmul_relu) + + +def test_get_consumers(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + block = sch.get_block("matmul") + (consumer,) = sch.get_consumers(block) + assert tvm.ir.structural_equal( + sch.get_sref(consumer).stmt, + sch.get_sref(sch.get_block("relu")).stmt, + ) + verify_trace_roundtrip(sch, mod=matmul_relu) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 7d3115428f5a..57c87e5dedf4 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -383,6 +383,127 @@ def compacted_storage_align_func(a: T.handle, c: T.handle) -> None: C[i, j] = B[0, j] * 2.0 +@T.prim_func +def padding_pattern_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (20, 20), "float32") + with T.block(): + B = T.alloc_buffer((20, 20), dtypes="float32") + for i, j in T.grid(16, 16): + with T.block(): + B[i, j] = A[i, j] + for i, j in T.grid(20, 20): + with T.block(): + C[i, j] = T.if_then_else( + 2 <= i and i < 18 and 2 <= j and j < 18, + B[i - 2, j - 2], + 0.0, + dtype="float32", + ) + + +@T.prim_func +def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16], dtype="float32") + C = T.match_buffer(c, [20, 20], dtype="float32") + with T.block(): + B = T.alloc_buffer([16, 16], dtype="float32") + for i, j in T.grid(16, 16): + with T.block(): + B[i, j] = A[i, j] + for i, j in T.grid(20, 20): + with T.block(): + C[i, j] = T.if_then_else( + 2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], 0.0, dtype="float32" + ) + + +@T.prim_func +def mem_access_in_branch_func(a: T.handle) -> None: + A = T.match_buffer(a, (224, 224), "float32") + with T.block(): + B1 = T.alloc_buffer((224, 224), dtypes="float32") + B2 = T.alloc_buffer((224, 224), dtypes="float32") + B3 = T.alloc_buffer((224, 224), dtypes="float32") + B4 = T.alloc_buffer((224, 224), dtypes="float32") + for i in range(0, 224): + for j in range(0, 224): + with T.block(): + if i < 112 and j < 112: + B1[i, j] = A[i, j] * 2.0 + else: + B2[i, j] = A[i, j] + 3.0 + for i in range(0, 224): + for j in range(0, 224): + with T.block(): + if i < 112 or j < 112: + B3[i, j] = A[i, j] * 2.0 + else: + B4[i, j] = A[i, j] + 3.0 + + +@T.prim_func +def compacted_mem_access_in_branch_func(a: T.handle) -> None: + A = T.match_buffer(a, [224, 224], dtype="float32") + with T.block(): + B1 = T.alloc_buffer([112, 112], dtype="float32") + B2 = T.alloc_buffer([224, 224], dtype="float32") + B3 = T.alloc_buffer([224, 224], dtype="float32") + B4 = T.alloc_buffer([112, 112], dtype="float32") + for i, j in T.grid(224, 224): + with T.block(): + if i < 112 and j < 112: + B1[i, j] = A[i, j] * 2.0 + else: + B2[i, j] = A[i, j] + 3.0 + for i, j in T.grid(224, 224): + with T.block(): + if i < 112 or j < 112: + B3[i, j] = A[i, j] * 2.0 + else: + B4[i - 112, j - 112] = A[i, j] + 3.0 + + +@T.prim_func +def opaque_access_annotated_func(a: T.handle) -> None: + A = T.match_buffer(a, (1024,), "float32") + with T.block(): + B = T.alloc_buffer((1024,), dtypes="float32") + C = T.alloc_buffer((1024,), dtypes="float32") + for i in range(0, 512): + with T.block(): + # no annotation, opaque access will cover full region + T.reads([]) + T.writes([]) + T.store(B.data, i, "float32", A[i]) + with T.block(): + # treat opaque access only access annotated regions, even if + # they are not compatible with actual buffer accesses. + T.reads([B[i]]) + T.writes([C[i : i + 9]]) + T.store(C.data, i, T.load("float32", B.data, i)) + + +@T.prim_func +def compacted_opaque_access_annotated_func(a: T.handle) -> None: + A = T.match_buffer(a, (1024,), "float32") + with T.block(): + B = T.alloc_buffer((1024,), dtypes="float32") + C = T.alloc_buffer((520,), dtypes="float32") + for i in range(0, 512): + with T.block(): + # no annotation, opaque access will cover full region + T.reads([]) + T.writes([]) + T.store(B.data, i, "float32", A[i]) + with T.block(): + # treat opaque access only access annotated regions, even if + # they are not compatible with actual buffer accesses. + T.reads([B[i]]) + T.writes([C[i : i + 9]]) + T.store(C.data, i, T.load("float32", B.data, i)) + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -428,6 +549,18 @@ def test_storage_align(): _check(storage_align_func, compacted_storage_align_func) +def test_padding_pattern(): + _check(padding_pattern_func, compacted_padding_pattern_func) + + +def test_mem_access_in_branch_func(): + _check(mem_access_in_branch_func, compacted_mem_access_in_branch_func) + + +def test_opaque_access_annotated_func(): + _check(opaque_access_annotated_func, compacted_opaque_access_annotated_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -439,3 +572,6 @@ def test_storage_align(): test_match_buffer() test_storage_align() test_lower_te() + test_padding_pattern() + test_mem_access_in_branch_func() + test_opaque_access_annotated_func() diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index ee323a64c50f..6859a5d75b75 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -58,14 +58,14 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: T.writes(C[i, 0:16]) B = T.alloc_buffer([16, 16], "float32") for j in range(0, 16): - with T.block() as []: - T.reads(A[i, j]) - T.writes(B[i, j]) + with T.block(): + T.reads([A[i, j]]) + T.writes([B[i, j]]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with T.block() as []: - T.reads(B[i, j]) - T.writes(C[i, j]) + with T.block(): + T.reads([B[i, j]]) + T.writes([C[i, j]]) C[i, j] = B[i, j] * 2.0 diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index cc78b84f9b4e..46d39c034454 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -82,13 +82,14 @@ def test_matmul_ir(A, B, C): # Create a dynamic shared memory for the accumulation. # This is for testing merging dynamic shared memory alloctions with different data type. # In practice, there is no need to allocate a shared memory for C. + C_local = ib.allocate(C.dtype, (1,), scope="local", name="C_local") C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 A_ptr = ib.buffer_ptr(A) B_ptr = ib.buffer_ptr(B) C_ptr = ib.buffer_ptr(C) - C_sh[ty, tx] = 0.0 + C_local[0] = 0.0 with ib.for_range(0, n // block, name="i") as i: A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] @@ -96,10 +97,10 @@ def test_matmul_ir(A, B, C): ib.emit(syncthread()) with ib.for_range(0, block, name="k") as k: - C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") - + C_local[0] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") ib.emit(syncthread()) + C_sh[ty, tx] = C_local[0] C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] return ib.get() @@ -113,7 +114,8 @@ def test_matmul_ir(A, B, C): ) s = te.create_schedule(C.op) mod = run_passes(s, [A, B, C]) - expected_alloc_size = block * block * 3 * 4 + # C can be allocated at the start of A, so we only need to allocate 2 block * block memory with dtype = float16 + expected_alloc_size = block * block * 4 verify_single_allocation(mod["main"].body, expected_alloc_size) def check_target(target): @@ -249,8 +251,83 @@ def test_device_ir(A, B, C, D): # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); verify_single_allocation(mod["main"].body) + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C, D], target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), dev) + d = tvm.nd.array(np.zeros((n,), dtype=D.dtype), dev) + fadd(a, b, c, d) + tvm.testing.assert_allclose(d.numpy(), a.numpy() + b.numpy() + c.numpy(), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +def test_dyn_shared_more_dtype(): + """Test vectorized store into dynamic shared memory""" + n = 512 + A = te.placeholder((n,), name="A", dtype="int8") + B = te.placeholder((n,), name="B", dtype="int16") + + def test_device_ir(A, B, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # i8 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # i16 + C_sh = ib.allocate(C.dtype, (n,), scope="shared.dyn") # i32 + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + A_sh[tx] = Aptr[tx] + B_sh[tx] = Bptr[tx] + + C_sh[tx] = cast(A_sh[tx], "int32") + cast(B_sh[tx], "int32") + Cptr[tx] = C_sh[tx] + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="int32", + ) + s = te.create_schedule(C.op) + + mod = run_passes(s, [A, B, C]) + verify_single_allocation(mod["main"].body, n * 4) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + if __name__ == "__main__": test_matmul_dyn_shared() test_dyn_shared_vectorized_store() test_dyn_shared_reuse_and_merge() + test_dyn_shared_more_dtype() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index b5620d748d8a..9b95266d3287 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -63,7 +63,8 @@ def check(m, n, target_bits, target_dtype): # const shape # i32 -> i32 check(2, 2, 32, "int32") - check(2 ** 16, 2 ** 16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow + # i32 + i32 is not promoted to i64 even if overflow + check(2 ** 16, 2 ** 16, 32, "int32") # i64 -> i32 check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32") check(const(2 ** 16, dtype="int64"), const(2 ** 16, dtype="int64"), 32, "int64") @@ -185,7 +186,7 @@ def check(m, n, target_bits, target_dtype): def test_relay_basic(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shapex, shapey, target_bits, target_dtype): x = relay.var("x", shape=shapex) @@ -227,7 +228,7 @@ def check(shapex, shapey, target_bits, target_dtype): def test_relay_take(): - engine = relay.backend.compile_engine.get() + engine = relay.backend.te_compiler.get() def check(shape, index, target_bits, target_dtype): x = relay.var("x", shape=shape) diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py new file mode 100644 index 000000000000..f4adac9cf742 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import tvm.testing + + +@tvm.testing.requires_cuda +def test_split_host_device_func_attr(): + m = te.size_var("m") + l = te.size_var("l") + A = te.placeholder((m, l), name="A") + + A1 = te.compute((m, l), lambda i, j: A[i, j], name="A1") + A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name="A2") + + s = te.create_schedule(A2.op) + xo, xi = s[A2].split(A2.op.axis[0], factor=8) + s[A2].bind(xo, te.thread_axis("blockIdx.x")) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + mod = tvm.lower(s, [A, A2], name="f") + + cuda_target = tvm.target.Target("cuda") + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) + )(mod) + fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] + + assert fdevice.attrs["global_symbol"] == "test_kernel0" + assert fdevice.attrs["calling_conv"].value == 2 + assert fdevice.attrs["target"] == cuda_target + assert fdevice.attrs["tir.is_global_func"].value + + +if __name__ == "__main__": + test_split_host_device_func_attr() diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index 1ce9b0cacd29..6880aabcd2f7 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import pytest +import sys + import tvm from tvm import te from tvm.script import tir as T @@ -35,6 +37,42 @@ def _check_fail(original): @T.prim_func def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i in T.thread_binding(0, 128, "blockIdx.x"): + for j0_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 + for j1_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j1_1 in T.serial(0, 32): + with T.block(""): + C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 + + +@T.prim_func +def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) + + +@T.prim_func +def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: j1_0 = T.env_thread("threadIdx.x") j0_0 = T.env_thread("threadIdx.x") i = T.env_thread("blockIdx.x") @@ -42,158 +80,152 @@ def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) T.launch_thread(i, 128) - with T.launch_thread(j0_0, 4): - for j0_1 in T.serial(0, 32): - T.store( - B.data, - i * 128 + j0_0 * 32 + j0_1, - T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0, - True, - ) + T.launch_thread(j0_0, 4) T.launch_thread(j1_0, 4) + + for j0_1 in T.serial(0, 32): + with T.block(""): + B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 for j1_1 in T.serial(0, 32): - T.store( - C.data, - i * 128 + j1_0 * 32 + j1_1, - T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0, - True, - ) + with T.block(""): + C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 @T.prim_func -def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: - thread_x = T.env_thread("threadIdx.x") - block_x = T.env_thread("blockIdx.x") +def unified_element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 128]) - T.launch_thread(block_x, 128) - with T.launch_thread(thread_x, 4): - for j0_1 in T.serial(0, 32): - T.store( - B.data, - block_x * 128 + thread_x * 32 + j0_1, - T.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0, - True, - ) - T.launch_thread(thread_x, 4) - for j1_1 in T.serial(0, 32): - T.store( - C.data, - block_x * 128 + thread_x * 32 + j1_1, - T.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0, - True, - ) + + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) @T.prim_func def element_wise_vthread_x(a: T.handle, b: T.handle) -> None: - i_0 = T.env_thread("vthread.x") - i_1 = T.env_thread("threadIdx.x") - j_0 = T.env_thread("vthread.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) - T.launch_thread(i_0, 2) - T.launch_thread(i_1, 64) - T.launch_thread(j_0, 2) - for j_1 in T.serial(0, 64): - T.store( - B.data, - i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1, - T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0, - True, - ) + for i_0 in T.thread_binding(0, 2, "vthread.x"): + for i_1 in T.thread_binding(0, 64, "threadIdx.x"): + for j_0 in T.thread_binding(0, 2, "vthread.x"): + for j_1 in T.serial(0, 64): + with T.block(""): + B[i_0 * 64 + i_1, j_0 * 64 + j_1] = A[i_0 * 64 + i_1, j_0 * 64 + j_1] * 2.0 @T.prim_func def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None: - vthread_x = T.env_thread("vthread.x") - thread_x = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) - T.launch_thread(vthread_x, 2) - T.launch_thread(thread_x, 64) - T.launch_thread(vthread_x, 2) - for j_1 in T.serial(0, 64): - T.store( - B.data, - vthread_x * 8256 + thread_x * 128 + j_1, - T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0, - True, - ) + for vthread_x in T.thread_binding(0, 2, "vthread.x"): + for threadIdx_x in T.thread_binding(0, 64, "threadIdx.x"): + for j_1 in T.serial(0, 64): + with T.block(""): + B[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] = ( + A[vthread_x * 64 + threadIdx_x, vthread_x * 64 + j_1] * 2.0 + ) @T.prim_func def element_wise_two_thread_x_in_same_kernel_not_equal( a: T.handle, b: T.handle, c: T.handle ) -> None: - i = T.env_thread("blockIdx.x") - j0 = T.env_thread("threadIdx.x") - j1 = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [128, 64]) - T.launch_thread(i, 128) - with T.launch_thread(j0, 128): - T.store(B.data, i * 64 + j0, T.load("float32", A.data, i * 128 + j0) * 2.0, True) - T.launch_thread(j1, 64) - T.store(C.data, i * 64 + j1, T.load("float32", A.data, i * 128 + j1) + 1.0, True) + for i in T.thread_binding(0, 128, "blockIdx.x"): + for j0 in T.thread_binding(0, 128, "threadIdx.x"): + B[i, j0] = A[i, j0] * 2.0 + for j1 in T.thread_binding(0, 64, "threadIdx.x"): + C[i, j1] = A[i, j1] + 1.0 @T.prim_func def element_wise_kernels_with_different_size( a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - i0 = T.env_thread("blockIdx.x") - j0 = T.env_thread("threadIdx.x") - i1 = T.env_thread("blockIdx.x") - j1 = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [256, 256]) D = T.match_buffer(d, [256, 256]) - with T.launch_thread(i0, 128): - T.launch_thread(j0, 128) - T.store(B.data, i0 * 128 + j0, T.load("float32", A.data, i0 * 128 + j0) * 2.0, True) - T.launch_thread(i1, 256) - T.launch_thread(j1, 256) - T.store(D.data, i1 * 256 + j1, T.load("float32", C.data, i1 * 256 + j1) + 1.0, True) + for i0 in T.thread_binding(0, 128, "blockIdx.x"): + for j0 in T.thread_binding(0, 128, "threadIdx.x"): + B[i0, j0] = A[i0, j0] * 2.0 + for i1 in T.thread_binding(0, 256, "blockIdx.x"): + for j1 in T.thread_binding(0, 256, "threadIdx.x"): + D[i1, j1] = C[i1, j1] + 1.0 @T.prim_func def unified_element_wise_kernels_with_different_size( a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - block_x = T.env_thread("blockIdx.x") - thread_x = T.env_thread("threadIdx.x") - block_x_1 = T.env_thread("blockIdx.x") - thread_x_1 = T.env_thread("threadIdx.x") A = T.match_buffer(a, [128, 128]) B = T.match_buffer(b, [128, 128]) C = T.match_buffer(c, [256, 256]) D = T.match_buffer(d, [256, 256]) - with T.launch_thread(block_x, 128): - T.launch_thread(thread_x, 128) - T.store( - B.data, - block_x * 128 + thread_x, - T.load("float32", A.data, block_x * 128 + thread_x) * 2.0, - True, - ) - T.launch_thread(block_x_1, 256) - T.launch_thread(thread_x_1, 256) - T.store( - D.data, - block_x_1 * 256 + thread_x_1, - T.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, - True, - ) + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 128, "threadIdx.x"): + B[blockIdx_x, threadIdx_x] = A[blockIdx_x, threadIdx_x] * 2.0 + for blockIdx_x in T.thread_binding(0, 256, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 256, "threadIdx.x"): + D[blockIdx_x, threadIdx_x] = C[blockIdx_x, threadIdx_x] + 1.0 + + +@T.prim_func +def element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i in T.thread_binding(0, 128, "threadIdx.y"): + for j0_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0 + for j1_0 in T.thread_binding(0, 4, "threadIdx.x"): + for j1_1 in T.serial(0, 32): + with T.block(""): + C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0 + + +@T.prim_func +def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for blockIdx_x in T.thread_binding(0, 128, "threadIdx.y"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) def test_thread_x(): _check(element_wise_thread_x, unified_element_wise_thread_x) +def test_env_thread_x(): + _check(element_wise_env_thread_x, unified_element_wise_env_thread_x) + + def test_vthread_x(): _check(element_wise_vthread_x, unified_element_wise_vthread_x) @@ -208,6 +240,10 @@ def test_kernels_with_different_size(): ) +def test_implicit_block(): + _check(element_wise_implicit_block, unified_element_wise_implicit_block) + + def test_lower_te(): a = te.placeholder((32, 2, 2)) b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0) @@ -220,8 +256,4 @@ def test_lower_te(): if __name__ == "__main__": - test_thread_x() - test_vthread_x() - test_two_thread_x_in_same_kernel_not_equal() - test_kernels_with_different_size() - test_lower_te() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 80c37229f519..4c7ffd6ccaea 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,6 +18,7 @@ import pytest import sys import tvm +from tvm import tir from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -511,5 +512,77 @@ def render(e): # TODO(Siyuan): block iter errors. + +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + vl = T.axis.S(128, l * 16) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block("C"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in T.serial(0, 128): + with T.block("B"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +def test_reorder_fail_block(): + sch = tir.Schedule(elementwise_not_affine, debug_mask="all") + block_b = sch.get_block("B") + i, j, k, l = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(l, i) + expected_sub_error_message = ( + " # tir.Block#0\n" + ' with T.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_reorder_fail_nested_loop_inner(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.reorder(k, i) + expected_sub_error_message = ( + " for i in T.serial(0, 128):\n" + " # tir.For#0\n" + " for j in T.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + +def test_fuse_fail_nested_loop_outer(): + sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all") + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError) as execinfo: + sch.fuse(k, i) + expected_sub_error_message = ( + " # tir.For#1\n" + " for i in T.serial(0, 128):\n" + " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n" + " for j in T.serial(0, 128):\n" + ) + assert expected_sub_error_message in str(execinfo.value) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 7c54cdc85f82..4e1308b030f1 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3095,5 +3095,86 @@ def test_primfunc_with_allocate_annotations(): tvm.ir.assert_structural_equal(func, rt_func, True) +# fmt: off +@T.prim_func +def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + threadIdx_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128], dtype="float32") + for i in T.serial(0, 128): + T.launch_thread(threadIdx_x, 128) + reduce_temp0 = T.allocate([1], "float32", "local") + with T.attr(T.comm_reducer(lambda x0, x1, y0, y1: (T.Select((x1 >= y1), x0, y0), T.Select((x1 >= y1), x1, y1)), [T.int32(-1), T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")): + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), T.load("float32", A.data, i * 128 + threadIdx_x), True, reduce_temp0, threadIdx_x, dtype="handle")) + + +@T.prim_func +def multiple_commreducer() -> None: + normal_reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + normal_reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp0 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + reduce_temp1 = T.buffer_decl([1], dtype="float32", strides=[1], scope="local") + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_maxelem_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp0[0], True, reduce_temp0.data, ax0_1, dtype="handle")) + for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("T_softmax_expsum_cross_thread_reduction"): + T.attr(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle")) + T.evaluate(T.tvm_thread_allreduce(T.uint32(1), normal_reduce_temp1[0], True, reduce_temp1.data, ax0_1, dtype="handle")) +# fmt: on + + +def test_primfunc_with_single_reduce_group_commreducer(): + func = comm_reducer_single_reduce_group + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_reduce_group_commreducer(): + func = comm_reducer_multiple_reduce_groups + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_primfunc_with_multiple_commreducer(): + func = multiple_commreducer + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +@T.prim_func +def func_div_mod(): + a = T.var("int32") + b = T.var("int32") + T.evaluate(a // b) + T.evaluate(a % b) + T.evaluate(a / b) + T.evaluate(T.truncmod(a, b)) + + +def test_div_mod(): + func = func_div_mod + rt_func = tvm.script.from_source(func.script()) + tvm.ir.assert_structural_equal(func, rt_func, True) + + assert isinstance(func.body[0].value, tvm.tir.FloorDiv) + assert isinstance(func.body[1].value, tvm.tir.FloorMod) + assert isinstance(func.body[2].value, tvm.tir.Div) + assert isinstance(func.body[3].value, tvm.tir.Mod) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_type.py b/tests/python/unittest/test_tvmscript_type.py new file mode 100644 index 000000000000..44ea04b5ed36 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_type.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement +from tvm.script import tir as T + +""" +This prim func include necessary buffer types that need to be checked +e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc. +""" + + +@T.prim_func +def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block("B"): + vi = T.axis.S(128, i0) + vj = T.axis.S(128, ax1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]}) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i1 in T.serial(0, 128): + with T.block("C"): + vi_1, vj_1 = T.axis.remap("SS", [i0, i1]) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1) + + +""" +This prim func include necessary thread types that need to be checked +e.g. env_thread, launch_thread, thread_binding etc. +""" + + +@T.prim_func +def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + j1_0 = T.env_thread("threadIdx.x") + j0_0 = T.env_thread("threadIdx.x") + i = T.env_thread("blockIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + T.launch_thread(i, 128) + T.launch_thread(j0_0, 4) + T.launch_thread(j1_0, 4) + + for blockIdx_x in T.thread_binding(0, 128, "blockIdx.x"): + for threadIdx_x in T.thread_binding(0, 4, "threadIdx.x"): + for j0_1 in T.serial(0, 32): + with T.block(""): + B[blockIdx_x, threadIdx_x * 32 + j0_1] = ( + A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0 + ) + for j1_1 in T.serial(0, 32): + with T.block(""): + C[blockIdx_x, threadIdx_x * 32 + j1_1] = ( + B[blockIdx_x, threadIdx_x * 32 + j1_1] + 1.0 + ) + + +# Not running any test as we only want to type-check here +if __name__ == "__main__": + pass diff --git a/tests/scripts/task_cpp_unittest.sh b/tests/scripts/task_cpp_unittest.sh index 3df7b580d79d..0022e084c1a9 100755 --- a/tests/scripts/task_cpp_unittest.sh +++ b/tests/scripts/task_cpp_unittest.sh @@ -45,3 +45,11 @@ cd apps/bundle_deploy rm -rf build make test_dynamic test_static cd ../.. + +# Test Arm(R) Cortex(R)-M55 CPU and Ethos(TM)-U55 NPU demo app +FVP_PATH="/opt/arm/FVP_Corstone_SSE-300_Ethos-U55" +if test -d $FVP_PATH && pip3 list | grep vela; then + cd apps/microtvm/ethosu + ./run_demo.sh --fvp_path $FVP_PATH --cmake_path /opt/arm/cmake/bin/cmake + cd ../../.. +fi diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index aba4663d5931..0e20fc22cfb2 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -19,6 +19,7 @@ set -e set -u set -o pipefail +source tests/scripts/setup-pytest-env.sh echo "Checking MyPy Type defs in the TensorIR schedule package." mypy --check-untyped-defs python/tvm/tir/schedule @@ -32,6 +33,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ +echo "Checking MyPy Type defs in the TIR package with unittest" +MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py + #TODO(@mikepapadim): This is failing atm # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." # mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ diff --git a/tests/scripts/task_python_ethosn_tests.sh b/tests/scripts/task_python_ethosn_tests.sh index ae9b82b679ef..525cc26d743e 100755 --- a/tests/scripts/task_python_ethosn_tests.sh +++ b/tests/scripts/task_python_ethosn_tests.sh @@ -29,6 +29,6 @@ make cython3 # Note: Default behaviour is to assume the test target is Ethos-N77 # but setting ETHOSN_VARIANT_CONFIG appropriately -# (e.g. ETHOSN_VARIANT_CONFIG=ETHOSN78_1TOPS_4PLE_448KSRAM) +# (e.g. ETHOSN_VARIANT_CONFIG=Ethos-N78_1TOPS_2PLE_RATIO) # switches the target to an Ethos-N78 configuration. run_pytest ctypes python-ethosn tests/python/contrib/test_ethosn diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 00b63af48646..8618619d65ad 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -60,8 +60,8 @@ run_pytest cython ${TVM_INTEGRATION_TESTSUITE_NAME}-dso_plugin_module apps/dso_p # TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME} tests/python/integration -if python -c "import tvm; from tvm.relay.op.contrib.ethosn import ethosn_available; print(ethosn_available().name)" -eq "SW_ONLY"; then - ETHOSN_VARIANT_CONFIG=ETHOSN78_1TOPS_4PLE_448KSRAM run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib-test_ethosn tests/python/contrib/test_ethosn +if python3 -c "import tvm; from tvm.relay.op.contrib.ethosn import ethosn_available; print(ethosn_available().name)" -eq "SW_ONLY"; then + ETHOSN_VARIANT_CONFIG=Ethos-N78_1TOPS_2PLE_RATIO run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib-test_ethosn tests/python/contrib/test_ethosn fi run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib diff --git a/tests/scripts/task_python_integration_i386only.sh b/tests/scripts/task_python_integration_i386only.sh new file mode 100755 index 000000000000..9c378a647e3e --- /dev/null +++ b/tests/scripts/task_python_integration_i386only.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +set -e +set -u + +export TVM_INTEGRATION_I386_ONLY=1 + +./tests/scripts/task_python_integration.sh diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 6632ebb1ca52..8de8b908ee09 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -23,13 +23,20 @@ set -x # NOTE(areusch): Adding to diagnose flaky timeouts source tests/scripts/setup-pytest-env.sh make cython3 -run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=qemu_x86 + +# Zephyr +run_pytest ctypes python-microtvm-zephyr-qemu_x86 tests/micro/zephyr --zephyr-board=qemu_x86 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv32 tests/micro/zephyr --zephyr-board=qemu_riscv32 +run_pytest ctypes python-microtvm-zephyr-qemu_riscv64 tests/micro/zephyr --zephyr-board=qemu_riscv64 + # Temporarily removing mps2_an512 from CI due to issue 8728: # https://github.com/apache/tvm/issues/8728 # run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --zephyr-board=mps2_an521 +# Arduino run_pytest ctypes python-microtvm-arduino apps/microtvm/arduino/template_project/tests run_pytest ctypes python-microtvm-arduino-nano33ble tests/micro/arduino --test-build-only --arduino-board=nano33ble run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --arduino-board=due +# STM32 run_pytest ctypes python-microtvm-stm32 tests/micro/stm32 diff --git a/version.py b/version.py index 2fa0928b7e2e..a2212df73a4a 100644 --- a/version.py +++ b/version.py @@ -126,19 +126,20 @@ def update(file_name, pattern, repl, dry_run=False): update = [] hit_counter = 0 need_update = False - for l in open(file_name): - result = re.findall(pattern, l) - if result: - assert len(result) == 1 - hit_counter += 1 - if result[0] != repl: - l = re.sub(pattern, repl, l) - need_update = True - print("%s: %s -> %s" % (file_name, result[0], repl)) - else: - print("%s: version is already %s" % (file_name, repl)) - - update.append(l) + with open(file_name) as file: + for l in file: + result = re.findall(pattern, l) + if result: + assert len(result) == 1 + hit_counter += 1 + if result[0] != repl: + l = re.sub(pattern, repl, l) + need_update = True + print("%s: %s -> %s" % (file_name, result[0], repl)) + else: + print("%s: version is already %s" % (file_name, repl)) + + update.append(l) if hit_counter != 1: raise RuntimeError("Cannot find version in %s" % file_name)