diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index dd85ef2a5d172..1b9ebb3411e23 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -42,11 +42,6 @@ jobs: - uses: actions/checkout@v2 - name: Initialize submodules run: git submodule update --recursive --init - - name: Lint Python - if: startsWith(matrix.os, 'macOS') - run: | - python3 -m pip install flake8 - python3 -m flake8 . --count --select=E9,F63,F7 --show-source --statistics - uses: actions/cache@v1 env: CACHE_NUMBER: 0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a481d703ebee..87fa2e059abb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -586,7 +586,7 @@ endif() # Create the `cpptest` target if we can find GTest. If not, we create dummy # targets that give the user an informative error message. if(GTEST_INCLUDE_DIR AND GTEST_LIB) - file(GLOB TEST_SRCS tests/cpp/*.cc) + file(GLOB_RECURSE TEST_SRCS tests/cpp/*.cc) add_executable(cpptest ${TEST_SRCS}) target_include_directories(cpptest SYSTEM PUBLIC ${GTEST_INCLUDE_DIR}) target_link_libraries(cpptest PRIVATE ${TVM_TEST_LIBRARY_NAME} ${GTEST_LIB} gtest_main pthread dl) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 14f8191707c86..b9ef0479c72f0 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -136,6 +136,7 @@ We do encourage everyone to work anything they are interested in. 
- [Jon Soifer](https://github.com/soiferj): @soiferj - [Zhixun Tan](https://github.com/phisiart): @phisiart - [Andrew Tulloch](https://github.com/ajtulloch): @ajtulloch +- [Jorn Tuyls](https://github.com/jtuyls): @jtuyls - [Luis Vega](https://github.com/vegaluisjose): @vegaluisjose - [Thomas Viehmann](https://github.com/t-vi): @t-vi - [Yao Wang](https://github.com/kevinthesun): @kevinthesun diff --git a/Jenkinsfile b/Jenkinsfile index fa16292050808..3a96fbee061d3 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -209,7 +209,7 @@ stage('Build') { make(ci_gpu, 'build', '-j2') pack_lib('gpu', tvm_multilib) // compiler test - sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu_vulkan.sh" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu_other.sh" make(ci_gpu, 'build2', '-j2') } } @@ -224,7 +224,6 @@ stage('Build') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" - sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_fsim.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_vta_tsim.sh" // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" @@ -300,6 +299,19 @@ stage('Unit Test') { } } }, + 'python3: CPU': { + node('CPU') { + ws(per_exec_ws("tvm/ut-python-cpu")) { + init_git() + unpack_lib('cpu', tvm_multilib_tsim) + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_ci_setup.sh" + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" + junit "build/pytest-results/*.xml" + } + } + } + }, 'python3: i386': { node('CPU') { ws(per_exec_ws("tvm/ut-python-i386")) { diff --git a/apps/ios_rpc/CMakeLists.txt b/apps/ios_rpc/CMakeLists.txt index 75f34b4dce05b..96d2d257d4ad1 100644 --- a/apps/ios_rpc/CMakeLists.txt +++ b/apps/ios_rpc/CMakeLists.txt @@ -39,7 +39,7 @@ endif() # 
It is required to load unsigned shared modules on real iOS devices ExternalProject_Add(custom_dso_loader GIT_REPOSITORY https://github.com/octoml/macho-dyld.git - GIT_TAG 48d1e8b5c40c7f5b744cb089634af17dd86125b2 + GIT_TAG 0742b8129de7df1130be355b74faa8c036265bfc PREFIX custom_dso_loader LOG_DOWNLOAD TRUE LOG_CONFIGURE TRUE @@ -54,24 +54,30 @@ ExternalProject_Add(custom_dso_loader -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=${CMAKE_BUILD_WITH_INSTALL_NAME_DIR} ) +if(NOT CMAKE_IOS_RPC_BUNDLE) + set(CMAKE_IOS_RPC_BUNDLE org.apache.tvmrpc) +endif() + # iOS RPC Xcode project wrapper to integrate into Cmake ExternalProject_Add(ios_rpc PREFIX ios_rpc - DEPENDS custom_dso_loader + DEPENDS custom_dso_loader tvm_runtime SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR} CONFIGURE_COMMAND "" INSTALL_COMMAND "" BUILD_COMMAND xcodebuild - -scheme tvmrpc + -target tvmrpc -configuration ${CMAKE_BUILD_TYPE} -project /tvmrpc.xcodeproj - -derivedDataPath -sdk ${CMAKE_OSX_SYSROOT} -arch ${CMAKE_OSX_ARCHITECTURES} -hideShellScriptEnvironment + -allowProvisioningUpdates build + SYMROOT= IPHONEOS_DEPLOYMENT_TARGET=${CMAKE_OSX_DEPLOYMENT_TARGET} DEVELOPMENT_TEAM=${CMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM} TVM_BUILD_DIR=${CMAKE_BINARY_DIR} - USE_CUSTOM_DSO_LOADER=YES + USE_CUSTOM_DSO_LOADER=1 + PRODUCT_BUNDLE_IDENTIFIER=${CMAKE_IOS_RPC_BUNDLE} ) diff --git a/apps/ios_rpc/README.md b/apps/ios_rpc/README.md index b2d9e199b3349..c268d15d01793 100644 --- a/apps/ios_rpc/README.md +++ b/apps/ios_rpc/README.md @@ -17,116 +17,240 @@ # iOS TVM RPC -This folder contains iOS RPC app that allows us to launch an rpc server on a iOS device(e.g. ipython) -and connect to it through python script and do testing on the python side as normal TVM RPC. -You will need XCode and an iOS device to use this. +This folder contains iOS RPC app that allows us to launch an rpc server on a iOS +device. You will need XCode and an iOS device to use this. 
-## RPC proxy -Start the RPC proxy by running in a terminal: - - python -m tvm.exec.rpc_proxy +## Table of Contents +* [Building](#building) + * [Building TVM runtime and custom DSO loader plugin](#building-tvm-runtime-and-custom-dso-loader-plugin) + * [Building iOS TVM RPC application](#building-ios-tvm-rpc-application) +* [Workflow](#workflow) + * [Standalone RPC](#standalone-rpc) + * [iOS RPC App with proxy](#ios-rpc-app-with-proxy) + * [iOS RPC App with tracker](#ios-rpc-app-with-tracker) +* [Communication without Wi-Fi and speed up in case of slow Wi-Fi](#communication-without-wi-fi-and-speed-up-in-case-of-slow-wi-fi) -On success, you should see something like this: - - INFO:root:RPCProxy: client port bind to 0.0.0.0:9090 - INFO:root:RPCProxy: Websock port bind to 8888 +## Building +### Building TVM runtime and custom DSO loader plugin +While iOS platform itself doesn't allow us to run an unsigned binary, there is a +partial ability to run JIT code on real iOS devices. While application is +running under debug session, system allows allocating memory with write and +execute permissions (a debugger requirement). So we can use this feature to +implement the `tvm.rpc.server.load_module` PackedFunc, used to load code over +RPC. For this purpose we use custom version of `dlopen` function which doesn't +check signature and permissions for module loading. This custom `dlopen` +mechanic is integrated into TVM RPC as plugin and registered to execution only +inside iOS RPC application. -IP-address of this machine will be used to initialize ```TVM_IOS_RPC_PROXY_HOST``` -environment variable (see below). +The custom implementation of `dlopen` and other functions from `dlfcn.h` header are placed in separate repository, +and will be downloaded automatically during cmake build for iOS. -## Building -Before start, please run ```init_proj.py``` to update XCode developer metadata. 
After this step, open -```tvmrpc.xcodeproj``` by using XCode, build the App and install the App on the phone. Usually, we -**do not** use the iOS App directly. +Also, it is necessary to build `libtvm_runtime.dylib` for our iOS device. The +iOS TVM RPC application will be linked with this library. -To test an App, you can fill ``Address`` field with IP-address of RPC proxy -(see above), and press ``Connect to Proxy``. +Run the build using the following commands: +```shell +export DEVELOPER_DIR=/Applications/Xcode.app # iOS SDK is part of Xcode bundle. Have to set it as default Dev Env +cmake .. + -DCMAKE_BUILD_TYPE=Debug + -DCMAKE_SYSTEM_NAME=iOS + -DCMAKE_SYSTEM_VERSION=14.0 + -DCMAKE_OSX_SYSROOT=iphoneos + -DCMAKE_OSX_ARCHITECTURES=arm64 + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 + -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON + -DUSE_IOS_RPC=ON # to enable build iOS RPC application from TVM project tree + -DUSE_METAL=ON # to enable Metal runtime -On success, "Disconnected" will change to "Connected". -On RPC proxy side you can see the next message in a log: +cmake --build . --target custom_dso_loader tvm_runtime +``` - INFO:root:Handler ready TCPSocketProxy::server:iphone +### Building iOS TVM RPC application +Before start, please run [init_proj.py](./init_proj.py) to update XCode developer metadata: +```shell +python3 init_proj.py --team_id XXXXXXXXXX --tvm_build_dir "/path/to/tvm/ios/build/folder" +``` +You can get value of your `team_id` in the following ways: +- **You have registered Apple Developer Profile**. In this case your developer + Team ID is available at https://developer.apple.com/account/#/membership +- You are using your local developer profile. In this case, leave `XXXXXXXXXX` + in the command instead of substituting a Team ID. Then open `tvmrpc.xcodeproj` + by using XCode, click on the project name (`tvmrpc`) on the left panel. Then + select target `tvmrpc`. 
At the bottom of this panel go to `Signing & + Capabilities` tab and in the field `Team` select your local developer profile + (`Your Name (Personal Team)`). + + On the first run of the application you may see message `Could not launch + "tvmrpc"` in the XCode and message `Untrusted Developer` on your device. In + this case it will be necessary to check the certificate. Open + `Settings -> General -> Device Management -> Apple Development: + -> Trust "Apple Development: "` and click `Trust`. After that you + should rerun your application in the XCode. -Now App can be closed by pressing the home button (or even removed from a device). +After this step, open `tvmrpc.xcodeproj` by using XCode, build the App and +install the App on the phone. ## Workflow -Due to security restriction of iOS10. We cannot upload dynamic libraries to the App and load it from sandbox. -Instead, we need to build a list of libraries, pack them into the app bundle, launch the RPC server and -connect to test the bundled libraries. We use ```xcodebuild test``` to automate this process. There is also -one more approach to workaround this limitation, for more details please take a look into section -[Custom DSO loader integration](#custom-dso-loader-plugin). +Due to security restriction of iOS10. We cannot upload dynamic libraries to the +App and load it from sandbox. Instead, we need to build a list of libraries, +pack them into the app bundle, launch the RPC server and connect to test the +bundled libraries. For more on the approach we use to work around this +limitation, please take a look into section +[Building TVM runtime and custom DSO loader plugin](#building-tvm-runtime-and-custom-dso-loader-plugin). -The test script [tests/ios_rpc_test.py](tests/ios_rpc_test.py) is a good template for the workflow. With this -script, we don't need to manually operate the iOS App, this script will build the app, run it and collect the results -automatically. 
+The test scripts [tests/ios_rpc_test.py](tests/ios_rpc_test.py) and +[tests/ios_rpc_mobilenet.py](tests/ios_rpc_mobilenet.py) are good templates for +demonstrating the workflow. - To run the script, you need to configure the following environment variables +We have three different modes for iOS RPC server: +- [Standalone RPC](#standalone-rpc): In this mode RPC server open port on the device and listening. Then + client connects to the server directly without any mediators. +- [iOS RPC application with Proxy](#ios-rpc-app-with-proxy): RPC server and RPC client communicates through + `rpc_proxy`. The RPC server on iOS device notify `rpc_proxy` which was run on + host machine about itself and wait for incoming connections. Communications + between client and server works through `rpc_proxy`. +- [iOS RPC application with Tracker](#ios-rpc-app-with-tracker): RPC server registered in the `rpc_tracker` + and client connects to the RPC server through `rpc_tracker`. -- ```TVM_IOS_CODESIGN``` The signature you use to codesign the app and libraries (e.g. ```iPhone Developer: Name (XXXX)```) -- ```TVM_IOS_TEAM_ID``` The developer Team ID available at https://developer.apple.com/account/#/membership -- ```TVM_IOS_RPC_ROOT``` The root directory of the iOS rpc project -- ```TVM_IOS_RPC_PROXY_HOST``` The RPC proxy address (see above) -- ```TVM_IOS_RPC_DESTINATION``` The Xcode target device (e.g. ```platform=iOS,id=xxxx```) +### Standalone RPC +Start RPC server on your iOS device: +- Push on the `Connect` button. -See instructions of how to find UUID of the iOS device: +After that you are supposed to see something like this in the app on the device: +``` +IP: +Port: +``` -- https://www.innerfence.com/howto/find-iphone-unique-device-identifier-udid +Printed `IP` is the IP address of your device and `PORT` is the number of port +which was open for RPC connection. Next you should use them to connect your RPC +client to the server. 
-## How it works -Let us explain how it works, the project look for ```rpc_config.txt``` file in the project root folder. -The ```rpc_config.txt``` file should be in the following format: +Let's check that direct RPC connection works and we can upload a library with +model and execute it on the device. For this purpose we will use +[ios_rpc_test.py](tests/ios_rpc_test.py). Run it: +```shell +python3 tests/ios_rpc_test.py --host --port --mode "standalone" ``` - -[path to dylib1] -[path to dylib2] -... +This will compile TVM IR to shared libraries (CPU and Metal) and run vector +addition on your iOS device. You are supposed to see something like this: +``` +Metal: 0.000338692 secs/op +CPU: 0.000219308 secs/op ``` -The build script will copy all the dynamic libraries into bundle ```tvmrpc.app/Frameworks/tvm```, -which you will be able to load via RPC using ```remote.load_module```. -It will also create an ```tvmrpc.app/Frameworks/tvm/rpc_config.txt``` containing the first line. -When we run the testcase, the testcase read the configuration from ```tvmrpc.app/Frameworks/tvm/rpc_config.txt``` -and connect to the specified RPC proxy, start serving loop. +### iOS RPC App with proxy +Start the RPC proxy by running in a terminal: +```shell +python3 -m tvm.exec.rpc_proxy --host 0.0.0.0 --port 9090 +``` -So if we want to start the RPC from XCode IDE, simply manually modify ```rpc_config.txt``` file and click test. -Then connect to the proxy via the python script. +On success, you should see something like this: +``` +INFO:root:RPCProxy: client port bind to 0.0.0.0:9090 +INFO:root:RPCProxy: Websock port bind to 8888 +``` +Connect your iOS device to the RPC proxy via the iOS TVM RPC application. Set +the `Address` and `Port` fields to the address and port of the RPC proxy +respectively. Select mode `Proxy` and push `Connect` button. On success, the +text on the button will be changed to `Disconnect` and `Disconnected` in the top +of the screen will be changed to `Connected`. 
+On RPC proxy side you can see the next message in a log: +``` +INFO:root:Handler ready TCPSocketProxy::server:iphone +``` +Then we can check that RPC connection works and we can upload a library with +model and execute it on the target device. For this purpose we will use +[ios_rpc_test.py](tests/ios_rpc_test.py). Run it: +```shell +python3 tests/ios_rpc_test.py --host --port 9090 --mode "proxy" +``` +The output should be the same as it was in previous section. -We can also use the RPC App directly, by typing in the address and press connect to connect to the proxy. -However, the restriction is we can only load the modules that are bundled to the App. +### iOS RPC App with tracker +First start an RPC tracker using +```shell +python3 -m tvm.exec.rpc_tracker --host 0.0.0.0 --port 9190 +``` +On success, you should see something like this: +``` +INFO:RPCTracker:bind to 0.0.0.0:9190 +``` +Connect your iOS device to the RPC tracker via the iOS TVM RPC application. Set +the `Address` and `Port` fields to the address and port of the RPC tracker +respectively. Select mode `Tracker` and push `Connect` button. On success, the +text on the button will be changed to `Disconnect` and `Disconnected` in the top +of the screen will be changed to `Connected`. On the host side you can check the +connection by the following command: +```shell +python3 -m tvm.exec.query_rpc_tracker --port 9190 +``` +You are supposed to see something like this: +``` +Tracker address 127.0.0.1:9190 -## Custom DSO loader plugin -While iOS platform itself doesn't allow us to run an unsigned binary, where is a partial ability to run JIT code -on real iOS devices. While application is running under debug session, system allows allocating memory with write -and execute permissions (requirements of debugger). So we can use this feature to load binary on RPC side. For this -purpose we use custom version of `dlopen` function which doesn't check signature and permissions for module loading. 
-This custom `dlopen` mechanic is integrated into TVM RPC as plugin and registered to execution only inside iOS RPC -application. +Server List +---------------------------- +server-address key +---------------------------- +192.168.1.57:9190 server:iphone +---------------------------- -The custom implementation of `dlopen` and other functions from `dlfcn.h` header are placed in separate repository, -and will be downloaded automatically during cmake build for iOS. To run cmake build you may use next flags: +Queue Status +------------------------------ +key total free pending +------------------------------ +iphone 1 1 0 +------------------------------ +``` + +Then we can check that RPC connection works and we can upload a library with +model and execute it on the target device. For this purpose we will use +[ios_rpc_test.py](tests/ios_rpc_test.py). Run it: ```shell -export DEVELOPER_DIR=/Applications/Xcode.app # iOS SDK is part of Xcode bundle. Have to set it as default Dev Env -cmake .. - -DCMAKE_BUILD_TYPE=Debug - -DCMAKE_SYSTEM_NAME=iOS - -DCMAKE_SYSTEM_VERSION=14.0 - -DCMAKE_OSX_SYSROOT=iphoneos - -DCMAKE_OSX_ARCHITECTURES=arm64 - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 - -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON - -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=XXXXXXXXXX # insert your Team ID - -DUSE_IOS_RPC=ON # to enable build iOS RPC application from TVM project tree -cmake --build . --target custom_dso_loader ios_rpc # Will use custom DSO loader by default -# Resulting iOS RPC app bundle will be placed in: -# apps/ios_rpc/ios_rpc/src/ios_rpc-build/Build/Products/[CONFIG]-iphoneos/tvmrpc.app +python3 tests/ios_rpc_test.py --host --port 9190 --mode "tracker" ``` +The output will be the same as in section +[Standalone RPC](#standalone-rpc). + +## Communication without Wi-Fi and speed up in case of slow Wi-Fi +Connection to the RPC server through `usbmux` can be used when you have slow, +unstable or don't have any Wi-Fi connection. 
`usbmux` is used for binding local +TCP port to port on the device and transfer packages between these ports by USB +cable. -To enable using of Custom DSO Plugin during xcode build outsde of Cmake you should specify two additional variables. -You can do it manually inside Xcode IDE or via command line args for `xcodebuild`. Make sure that `custom_dso_loader` -target from previous step is already built. -* TVM_BUILD_DIR=path-to-tvm-ios-build-dir -* USE_CUSTOM_DSO_LOADER=1 +First of all you should install `usbmux` to your system. You can do it with +brew: +```shell +brew install usbmuxd +``` +After that you can use `iproxy` program for binding ports. You can use it for +all described workflows. Let's take a look how it works for +[Standalone RPC](#standalone-rpc). -iOS RPC application with enabled custom DSO loader is able to process modules passed via regular -`remote.upload("my_module.dylib")` mechanics. For example take a look inside `test_rpc_module_with_upload` test case -of file [ios_rpc_test.py](tests/ios_rpc_test.py). +First, start RPC server on your iOS device. You may see something like this in +the app on the device: +``` +IP: unknown +Port: +``` +**Note.** Here `IP: unknown` because there was no Internet connection on the iOS +device. +Printed `Port` is the port of the RPC server on your iOS device. We will use it +in binding ports. Run `iproxy`, specify local port which should be used for +communication with device and the printed port on the device: +```shell +iproxy : +``` +After this command you should see something like this: +``` +Creating listening port for device port +waiting for connection +``` +Now we can check that RPC connection through `usbmux` works and we can upload a +library with model and execute it on the device. For this purpose we will use +[ios_rpc_test.py](tests/ios_rpc_test.py). Run it: +```shell +python3 tests/ios_rpc_test.py --host 0.0.0.0 --port --mode standalone +``` +The output should be the same as in all previous runs. 
diff --git a/apps/ios_rpc/init_proj.py b/apps/ios_rpc/init_proj.py index deee86ce4876d..9044a9e8cbbf0 100644 --- a/apps/ios_rpc/init_proj.py +++ b/apps/ios_rpc/init_proj.py @@ -18,7 +18,7 @@ import re default_team_id = "3FR42MXLK9" -default_bundle_identifier = "org.apache.tvmrpc" +default_tvm_build_dir = "path-to-tvm-ios-build-folder" parser = argparse.ArgumentParser( description="Update tvmrpc.xcodeproj\ @@ -38,26 +38,22 @@ ) parser.add_argument( - "--bundle_identifier", + "--tvm_build_dir", type=str, - required=False, - default=default_bundle_identifier, - help="The new bundle identifier\n\ - (example: {})".format( - default_bundle_identifier - ), + required=True, + help="Path to directory with libtvm_runtime.dylib", ) args = parser.parse_args() team_id = args.team_id -bundle_identifier = args.bundle_identifier +tvm_build_dir = args.tvm_build_dir fi = open("tvmrpc.xcodeproj/project.pbxproj") proj_config = fi.read() fi.close() proj_config = proj_config.replace(default_team_id, team_id) -proj_config = proj_config.replace(default_bundle_identifier, bundle_identifier) +proj_config = proj_config.replace(default_tvm_build_dir, tvm_build_dir) fo = open("tvmrpc.xcodeproj/project.pbxproj", "w") fo.write(proj_config) fo.close() diff --git a/apps/ios_rpc/tests/ios_rpc_mobilenet.py b/apps/ios_rpc/tests/ios_rpc_mobilenet.py index a57db729d8a6b..e4caa329bdf59 100644 --- a/apps/ios_rpc/tests/ios_rpc_mobilenet.py +++ b/apps/ios_rpc/tests/ios_rpc_mobilenet.py @@ -32,20 +32,7 @@ from mxnet import gluon from PIL import Image import coremltools - -# Set to be address of tvm proxy. -proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"] -# Set your desination via env variable. 
-# Should in format "platform=iOS,id=" -destination = os.environ["TVM_IOS_RPC_DESTINATION"] - -if not re.match(r"^platform=.*,id=.*$", destination): - print("Bad format: {}".format(destination)) - print("Example of expected string: platform=iOS,id=1234567890abcabcabcabc1234567890abcabcab") - sys.exit(1) - -proxy_port = 9090 -key = "iphone" +import argparse # Change target configuration, this is setting for iphone6s # arch = "x86_64" @@ -54,6 +41,8 @@ sdk = "iphoneos" target_host = "llvm -mtriple=%s-apple-darwin" % arch +MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect} + # override metal compiler to compile to iphone @tvm.register_func("tvm_callback_metal_compile") def compile_metal(src): @@ -97,7 +86,7 @@ def get_model(model_name, data_shape): return func, params -def test_mobilenet(): +def test_mobilenet(host, port, key, mode): temp = utils.tempdir() image, synset = prepare_input() model, params = get_model("mobilenetv2_1.0", image.shape) @@ -107,13 +96,13 @@ def run(mod, target): lib = relay.build(mod, target=target, target_host=target_host, params=params) path_dso = temp.relpath("deploy.dylib") lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk) - xcode.codesign(path_dso) - - # Start RPC test server that contains the compiled library. 
- xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination, libs=[path_dso]) # connect to the proxy - remote = rpc.connect(proxy_host, proxy_port, key=key) + if mode == "tracker": + remote = MODES[mode](host, port).request(key) + else: + remote = MODES[mode](host, port, key=key) + remote.upload(path_dso) if target == "metal": dev = remote.metal(0) @@ -175,4 +164,19 @@ def annotate(func, compiler): if __name__ == "__main__": - test_mobilenet() + parser = argparse.ArgumentParser(description="Demo app demonstrates how ios_rpc works.") + parser.add_argument("--host", required=True, type=str, help="Adress of rpc server") + parser.add_argument("--port", type=int, default=9090, help="rpc port (default: 9090)") + parser.add_argument("--key", type=str, default="iphone", help="device key (default: iphone)") + parser.add_argument( + "--mode", + type=str, + default="tracker", + help="type of RPC connection (default: tracker), possible values: {}".format( + ", ".join(MODES.keys()) + ), + ) + + args = parser.parse_args() + assert args.mode in MODES.keys() + test_mobilenet(args.host, args.port, args.key, args.mode) diff --git a/apps/ios_rpc/tests/ios_rpc_test.py b/apps/ios_rpc/tests/ios_rpc_test.py index 0f81dcce929f0..733ab912ecfc9 100644 --- a/apps/ios_rpc/tests/ios_rpc_test.py +++ b/apps/ios_rpc/tests/ios_rpc_test.py @@ -28,33 +28,22 @@ from tvm import rpc from tvm.contrib import utils, xcode import numpy as np - -# Set to be address of tvm proxy. -proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"] -# Set your destination via env variable. 
-# Should in format "platform=iOS,id=" -destination = os.environ["TVM_IOS_RPC_DESTINATION"] - -if not re.match(r"^platform=.*,id=.*$", destination): - print("Bad format: {}".format(destination)) - print("Example of expected string: platform=iOS,id=1234567890abcabcabcabc1234567890abcabcab") - sys.exit(1) - -proxy_port = 9090 -key = "iphone" +import argparse # Change target configuration, this is setting for iphone6s arch = "arm64" sdk = "iphoneos" target = "llvm -mtriple=%s-apple-darwin" % arch +MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect} + # override metal compiler to compile to iphone @tvm.register_func("tvm_callback_metal_compile") def compile_metal(src): return xcode.compile_metal(src, sdk=sdk) -def test_rpc_module(): +def test_rpc_module(host, port, key, mode): # graph n = tvm.runtime.convert(1024) A = te.placeholder((n,), name="A") @@ -69,7 +58,6 @@ def test_rpc_module(): f = tvm.build(s, [A, B], "metal", target_host=target, name="myadd") path_dso1 = temp.relpath("dev_lib.dylib") f.export_library(path_dso1, xcode.create_dylib, arch=arch, sdk=sdk) - xcode.codesign(path_dso1) s = te.create_schedule(B.op) xo, xi = s[B].split(B.op.axis[0], factor=64) @@ -79,15 +67,13 @@ def test_rpc_module(): f = tvm.build(s, [A, B], target, name="myadd_cpu") path_dso2 = temp.relpath("cpu_lib.dylib") f.export_library(path_dso2, xcode.create_dylib, arch=arch, sdk=sdk) - xcode.codesign(path_dso2) - - # Start RPC test server that contains the compiled library. 
- server = xcode.popen_test_rpc( - proxy_host, proxy_port, key, destination=destination, libs=[path_dso1, path_dso2] - ) # connect to the proxy - remote = rpc.connect(proxy_host, proxy_port, key=key) + if mode == "tracker": + remote = MODES[mode](host, port).request(key) + else: + remote = MODES[mode](host, port, key=key) + remote.upload(path_dso1) dev = remote.metal(0) f1 = remote.load_module("dev_lib.dylib") a_np = np.random.uniform(size=1024).astype(A.dtype) @@ -95,56 +81,35 @@ def test_rpc_module(): b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) time_f = f1.time_evaluator(f1.entry_name, dev, number=10) cost = time_f(a, b).mean - print("%g secs/op" % cost) + print("Metal: %g secs/op" % cost) np.testing.assert_equal(b.numpy(), a.numpy() + 1) # CPU dev = remote.cpu(0) + remote.upload(path_dso2) f2 = remote.load_module("cpu_lib.dylib") a_np = np.random.uniform(size=1024).astype(A.dtype) a = tvm.nd.array(a_np, dev) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) time_f = f2.time_evaluator(f2.entry_name, dev, number=10) cost = time_f(a, b).mean - print("%g secs/op" % cost) - np.testing.assert_equal(b.numpy(), a.numpy() + 1) - - -def test_rpc_module_with_upload(): - server = xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination) - - remote = rpc.connect(proxy_host, proxy_port, key=key) - try: - remote.get_function("runtime.module.loadfile_dylib_custom") - except AttributeError as e: - print(e) - print("Skip test. 
You are using iOS RPC without custom DSO loader enabled.") - return - - n = tvm.runtime.convert(1024) - A = te.placeholder((n,), name="A") - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") - temp = utils.tempdir() - s = te.create_schedule(B.op) - xo, xi = s[B].split(B.op.axis[0], factor=64) - s[B].parallel(xi) - s[B].pragma(xo, "parallel_launch_point") - s[B].pragma(xi, "parallel_barrier_when_finish") - f = tvm.build(s, [A, B], target, name="myadd_cpu") - path_dso = temp.relpath("cpu_lib.dylib") - f.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk) - - dev = remote.cpu(0) - remote.upload(path_dso) - f = remote.load_module("cpu_lib.dylib") - a_np = np.random.uniform(size=1024).astype(A.dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), dev) - time_f = f.time_evaluator(f.entry_name, dev, number=10) - cost = time_f(a, b).mean - print("%g secs/op" % cost) + print("CPU: %g secs/op" % cost) np.testing.assert_equal(b.numpy(), a.numpy() + 1) if __name__ == "__main__": - test_rpc_module() - test_rpc_module_with_upload() + parser = argparse.ArgumentParser(description="Demo app demonstrates how ios_rpc works.") + parser.add_argument("--host", required=True, type=str, help="Adress of rpc server") + parser.add_argument("--port", type=int, default=9090, help="rpc port (default: 9090)") + parser.add_argument("--key", type=str, default="iphone", help="device key (default: iphone)") + parser.add_argument( + "--mode", + type=str, + default="tracker", + help="type of RPC connection (default: tracker), possible values: {}".format( + ", ".join(MODES.keys()) + ), + ) + + args = parser.parse_args() + assert args.mode in MODES.keys() + test_rpc_module(args.host, args.port, args.key, args.mode) diff --git a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj index a5b69c829e4f7..61427d0ca248b 100644 --- a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj +++ 
b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj @@ -26,6 +26,10 @@ objects = { /* Begin PBXBuildFile section */ + 016B19C22657B390002E1719 /* RPCServer.mm in Sources */ = {isa = PBXBuildFile; fileRef = 016B19C12657B390002E1719 /* RPCServer.mm */; }; + 01A1DB432652CBA700655BBC /* RPCArgs.mm in Sources */ = {isa = PBXBuildFile; fileRef = 01A1DB412652CBA700655BBC /* RPCArgs.mm */; }; + 01A9B7B3265BD1FD000D092F /* libtvm_runtime.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = 01A9B7B2265BD1FD000D092F /* libtvm_runtime.dylib */; }; + 01A9B7B8265BD307000D092F /* libtvm_runtime.dylib in Embed Libraries */ = {isa = PBXBuildFile; fileRef = 01A9B7B2265BD1FD000D092F /* libtvm_runtime.dylib */; settings = {ATTRIBUTES = (CodeSignOnCopy, ); }; }; C02637501F1C25E8007247A9 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = C026374F1F1C25E8007247A9 /* main.m */; }; C02637531F1C25E8007247A9 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = C02637521F1C25E8007247A9 /* AppDelegate.m */; }; C02637591F1C25E8007247A9 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = C02637571F1C25E8007247A9 /* Main.storyboard */; }; @@ -33,21 +37,29 @@ C026375E1F1C25E8007247A9 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = C026375C1F1C25E8007247A9 /* LaunchScreen.storyboard */; }; C02637661F1C2690007247A9 /* TVMRuntime.mm in Sources */ = {isa = PBXBuildFile; fileRef = C02637651F1C2690007247A9 /* TVMRuntime.mm */; }; C02637691F1C26AF007247A9 /* ViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = C02637681F1C26AF007247A9 /* ViewController.mm */; }; - C05A2C891F1DCE0900D4798B /* tvmrpcLauncher.mm in Sources */ = {isa = PBXBuildFile; fileRef = C05A2C881F1DCE0900D4798B /* tvmrpcLauncher.mm */; }; D7685AD324390EAE00D1469C /* CoreML.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = D7685AD224390EAD00D1469C /* CoreML.framework */; }; /* End PBXBuildFile section */ -/* Begin PBXContainerItemProxy section */ - 
C05A2C8B1F1DCE0900D4798B /* PBXContainerItemProxy */ = { - isa = PBXContainerItemProxy; - containerPortal = C02637431F1C25E8007247A9 /* Project object */; - proxyType = 1; - remoteGlobalIDString = C026374A1F1C25E8007247A9; - remoteInfo = tvmrpc; +/* Begin PBXCopyFilesBuildPhase section */ + 01A9B7B9265BD307000D092F /* Embed Libraries */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 10; + files = ( + 01A9B7B8265BD307000D092F /* libtvm_runtime.dylib in Embed Libraries */, + ); + name = "Embed Libraries"; + runOnlyForDeploymentPostprocessing = 0; }; -/* End PBXContainerItemProxy section */ +/* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ + 016B19C02657B390002E1719 /* RPCServer.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RPCServer.h; sourceTree = ""; }; + 016B19C12657B390002E1719 /* RPCServer.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RPCServer.mm; sourceTree = ""; }; + 01A1DB402652CBA700655BBC /* RPCArgs.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RPCArgs.h; sourceTree = ""; }; + 01A1DB412652CBA700655BBC /* RPCArgs.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RPCArgs.mm; sourceTree = ""; }; + 01A9B7B2265BD1FD000D092F /* libtvm_runtime.dylib */ = {isa = PBXFileReference; lastKnownFileType = "compiled.mach-o.dylib"; name = libtvm_runtime.dylib; path = "${TVM_BUILD_DIR}/libtvm_runtime.dylib"; sourceTree = ""; }; C026374B1F1C25E8007247A9 /* tvmrpc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tvmrpc.app; sourceTree = BUILT_PRODUCTS_DIR; }; C026374F1F1C25E8007247A9 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; C02637511F1C25E8007247A9 /* AppDelegate.h */ = 
{isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; @@ -58,11 +70,7 @@ C026375D1F1C25E8007247A9 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = ""; }; C026375F1F1C25E8007247A9 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; C02637651F1C2690007247A9 /* TVMRuntime.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = TVMRuntime.mm; sourceTree = ""; }; - C02637671F1C269B007247A9 /* TVMRuntime.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = TVMRuntime.h; sourceTree = ""; }; C02637681F1C26AF007247A9 /* ViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ViewController.mm; sourceTree = ""; }; - C05A2C861F1DCE0900D4798B /* tvmrpcLauncher.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = tvmrpcLauncher.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; - C05A2C881F1DCE0900D4798B /* tvmrpcLauncher.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = tvmrpcLauncher.mm; sourceTree = ""; }; - C05A2C8A1F1DCE0900D4798B /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; D7685AD224390EAD00D1469C /* CoreML.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreML.framework; path = System/Library/Frameworks/CoreML.framework; sourceTree = SDKROOT; }; /* End PBXFileReference section */ @@ -71,17 +79,11 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + 01A9B7B3265BD1FD000D092F /* libtvm_runtime.dylib in Frameworks */, D7685AD324390EAE00D1469C /* CoreML.framework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 
0; }; - C05A2C831F1DCE0900D4798B /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - runOnlyForDeploymentPostprocessing = 0; - }; /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ @@ -89,7 +91,6 @@ isa = PBXGroup; children = ( C026374D1F1C25E8007247A9 /* tvmrpc */, - C05A2C871F1DCE0900D4798B /* tvmrpcLauncher */, C026374C1F1C25E8007247A9 /* Products */, D7685AD124390EAD00D1469C /* Frameworks */, ); @@ -101,7 +102,6 @@ isa = PBXGroup; children = ( C026374B1F1C25E8007247A9 /* tvmrpc.app */, - C05A2C861F1DCE0900D4798B /* tvmrpcLauncher.xctest */, ); name = Products; sourceTree = ""; @@ -109,12 +109,15 @@ C026374D1F1C25E8007247A9 /* tvmrpc */ = { isa = PBXGroup; children = ( - C02637681F1C26AF007247A9 /* ViewController.mm */, - C02637671F1C269B007247A9 /* TVMRuntime.h */, + 016B19C02657B390002E1719 /* RPCServer.h */, + 016B19C12657B390002E1719 /* RPCServer.mm */, + 01A1DB402652CBA700655BBC /* RPCArgs.h */, + 01A1DB412652CBA700655BBC /* RPCArgs.mm */, C02637651F1C2690007247A9 /* TVMRuntime.mm */, C02637511F1C25E8007247A9 /* AppDelegate.h */, C02637521F1C25E8007247A9 /* AppDelegate.m */, C02637541F1C25E8007247A9 /* ViewController.h */, + C02637681F1C26AF007247A9 /* ViewController.mm */, C02637571F1C25E8007247A9 /* Main.storyboard */, C026375A1F1C25E8007247A9 /* Assets.xcassets */, C026375C1F1C25E8007247A9 /* LaunchScreen.storyboard */, @@ -132,18 +135,10 @@ name = "Supporting Files"; sourceTree = ""; }; - C05A2C871F1DCE0900D4798B /* tvmrpcLauncher */ = { - isa = PBXGroup; - children = ( - C05A2C881F1DCE0900D4798B /* tvmrpcLauncher.mm */, - C05A2C8A1F1DCE0900D4798B /* Info.plist */, - ); - path = tvmrpcLauncher; - sourceTree = ""; - }; D7685AD124390EAD00D1469C /* Frameworks */ = { isa = PBXGroup; children = ( + 01A9B7B2265BD1FD000D092F /* libtvm_runtime.dylib */, D7685AD224390EAD00D1469C /* CoreML.framework */, ); name = Frameworks; @@ -159,7 +154,7 @@ C02637471F1C25E8007247A9 /* Sources */, 
C02637481F1C25E8007247A9 /* Frameworks */, C02637491F1C25E8007247A9 /* Resources */, - C05A2C901F1E683A00D4798B /* ShellScript */, + 01A9B7B9265BD307000D092F /* Embed Libraries */, ); buildRules = ( ); @@ -170,24 +165,6 @@ productReference = C026374B1F1C25E8007247A9 /* tvmrpc.app */; productType = "com.apple.product-type.application"; }; - C05A2C851F1DCE0900D4798B /* tvmrpcLauncher */ = { - isa = PBXNativeTarget; - buildConfigurationList = C05A2C8F1F1DCE0900D4798B /* Build configuration list for PBXNativeTarget "tvmrpcLauncher" */; - buildPhases = ( - C05A2C821F1DCE0900D4798B /* Sources */, - C05A2C831F1DCE0900D4798B /* Frameworks */, - C05A2C841F1DCE0900D4798B /* Resources */, - ); - buildRules = ( - ); - dependencies = ( - C05A2C8C1F1DCE0900D4798B /* PBXTargetDependency */, - ); - name = tvmrpcLauncher; - productName = tvmrpcLauncher; - productReference = C05A2C861F1DCE0900D4798B /* tvmrpcLauncher.xctest */; - productType = "com.apple.product-type.bundle.unit-test"; - }; /* End PBXNativeTarget section */ /* Begin PBXProject section */ @@ -202,12 +179,6 @@ DevelopmentTeam = 3FR42MXLK9; ProvisioningStyle = Automatic; }; - C05A2C851F1DCE0900D4798B = { - CreatedOnToolsVersion = 8.3.3; - DevelopmentTeam = 3FR42MXLK9; - ProvisioningStyle = Automatic; - TestTargetID = C026374A1F1C25E8007247A9; - }; }; }; buildConfigurationList = C02637461F1C25E8007247A9 /* Build configuration list for PBXProject "tvmrpc" */; @@ -215,6 +186,7 @@ developmentRegion = English; hasScannedForEncodings = 0; knownRegions = ( + English, en, Base, ); @@ -224,7 +196,6 @@ projectRoot = ""; targets = ( C026374A1F1C25E8007247A9 /* tvmrpc */, - C05A2C851F1DCE0900D4798B /* tvmrpcLauncher */, ); }; /* End PBXProject section */ @@ -240,61 +211,24 @@ ); runOnlyForDeploymentPostprocessing = 0; }; - C05A2C841F1DCE0900D4798B /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - runOnlyForDeploymentPostprocessing = 0; - }; /* End PBXResourcesBuildPhase section 
*/ -/* Begin PBXShellScriptBuildPhase section */ - C05A2C901F1E683A00D4798B /* ShellScript */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "libpath=${CONFIGURATION_BUILD_DIR}/${CONTENTS_FOLDER_PATH}/Frameworks/tvm\nmkdir -p ${libpath}\nrm -rf ${libpath}/*\n \nif [ -f ${SRCROOT}/rpc_config.txt ]; then\n head -n 1 ${SRCROOT}/rpc_config.txt > ${libpath}/rpc_config.txt\n tail -n +2 ${SRCROOT}/rpc_config.txt | xargs -J % cp -r % ${libpath}\nfi\n\n"; - }; -/* End PBXShellScriptBuildPhase section */ - /* Begin PBXSourcesBuildPhase section */ C02637471F1C25E8007247A9 /* Sources */ = { isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( C02637691F1C26AF007247A9 /* ViewController.mm in Sources */, + 01A1DB432652CBA700655BBC /* RPCArgs.mm in Sources */, + 016B19C22657B390002E1719 /* RPCServer.mm in Sources */, C02637531F1C25E8007247A9 /* AppDelegate.m in Sources */, C02637661F1C2690007247A9 /* TVMRuntime.mm in Sources */, C02637501F1C25E8007247A9 /* main.m in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; - C05A2C821F1DCE0900D4798B /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - C05A2C891F1DCE0900D4798B /* tvmrpcLauncher.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; /* End PBXSourcesBuildPhase section */ -/* Begin PBXTargetDependency section */ - C05A2C8C1F1DCE0900D4798B /* PBXTargetDependency */ = { - isa = PBXTargetDependency; - target = C026374A1F1C25E8007247A9 /* tvmrpc */; - targetProxy = C05A2C8B1F1DCE0900D4798B /* PBXContainerItemProxy */; - }; -/* End PBXTargetDependency section */ - /* Begin PBXVariantGroup section */ C02637571F1C25E8007247A9 /* Main.storyboard */ = { isa = PBXVariantGroup; @@ -340,6 +274,7 @@ "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; COPY_PHASE_STRIP = NO; 
DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_BITCODE = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; ENABLE_TESTABILITY = YES; GCC_C_LANGUAGE_STANDARD = gnu99; @@ -351,6 +286,7 @@ "$(inherited)", "DMLC_USE_LOGGING_LIBRARY=", "TVM_USE_LIBBACKTRACE=0", + "TVM_LOG_CUSTOMIZE=1", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; @@ -391,6 +327,7 @@ "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; COPY_PHASE_STRIP = NO; DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_BITCODE = NO; ENABLE_NS_ASSERTIONS = NO; ENABLE_STRICT_OBJC_MSGSEND = YES; GCC_C_LANGUAGE_STANDARD = gnu99; @@ -398,6 +335,7 @@ GCC_PREPROCESSOR_DEFINITIONS = ( "DMLC_USE_LOGGING_LIBRARY=", "TVM_USE_LIBBACKTRACE=0", + "TVM_LOG_CUSTOMIZE=1", ); GCC_WARN_64_TO_32_BIT_CONVERSION = YES; GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; @@ -417,7 +355,6 @@ isa = XCBuildConfiguration; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ENABLE_OBJC_ARC = NO; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_DOCUMENTATION_COMMENTS = NO; DEVELOPMENT_TEAM = 3FR42MXLK9; @@ -439,10 +376,10 @@ "${TVM_BUILD_DIR}", ); OTHER_LDFLAGS = "${_DSO_LOADER_NAME_${USE_CUSTOM_DSO_LOADER}}"; - PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmrpc; + PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmiosrpc; PRODUCT_NAME = "$(TARGET_NAME)"; TVM_BUILD_DIR = "path-to-tvm-ios-build-folder"; - USE_CUSTOM_DSO_LOADER = 0; + USE_CUSTOM_DSO_LOADER = 1; WARNING_CFLAGS = "-Wno-shorten-64-to-32"; _DSO_LOADER_NAME_0 = ""; _DSO_LOADER_NAME_1 = "-lmacho_dyld"; @@ -453,7 +390,6 @@ isa = XCBuildConfiguration; buildSettings = { ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ENABLE_OBJC_ARC = NO; CLANG_ENABLE_OBJC_WEAK = YES; CLANG_WARN_DOCUMENTATION_COMMENTS = NO; DEVELOPMENT_TEAM = 3FR42MXLK9; @@ -475,52 +411,16 @@ "${TVM_BUILD_DIR}", ); OTHER_LDFLAGS = "${_DSO_LOADER_NAME_${USE_CUSTOM_DSO_LOADER}}"; - PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmrpc; + PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmiosrpc; PRODUCT_NAME = 
"$(TARGET_NAME)"; TVM_BUILD_DIR = "path-to-tvm-ios-build-folder"; - USE_CUSTOM_DSO_LOADER = 0; + USE_CUSTOM_DSO_LOADER = 1; WARNING_CFLAGS = "-Wno-shorten-64-to-32"; _DSO_LOADER_NAME_0 = ""; _DSO_LOADER_NAME_1 = "-lmacho_dyld"; }; name = Release; }; - C05A2C8D1F1DCE0900D4798B /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - BUNDLE_LOADER = "$(TEST_HOST)"; - DEVELOPMENT_TEAM = 3FR42MXLK9; - HEADER_SEARCH_PATHS = ( - ../../3rdparty/dlpack/include, - ../../include, - "../../3rdparty/dmlc-core/include", - ); - INFOPLIST_FILE = tvmrpcLauncher/Info.plist; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; - PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmrpcLauncher; - PRODUCT_NAME = "$(TARGET_NAME)"; - TEST_HOST = "$(BUILT_PRODUCTS_DIR)/tvmrpc.app/tvmrpc"; - }; - name = Debug; - }; - C05A2C8E1F1DCE0900D4798B /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - BUNDLE_LOADER = "$(TEST_HOST)"; - DEVELOPMENT_TEAM = 3FR42MXLK9; - HEADER_SEARCH_PATHS = ( - ../../3rdparty/dlpack/include, - ../../include, - "../../3rdparty/dmlc-core/include", - ); - INFOPLIST_FILE = tvmrpcLauncher/Info.plist; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; - PRODUCT_BUNDLE_IDENTIFIER = org.apache.tvmrpcLauncher; - PRODUCT_NAME = "$(TARGET_NAME)"; - TEST_HOST = "$(BUILT_PRODUCTS_DIR)/tvmrpc.app/tvmrpc"; - }; - name = Release; - }; /* End XCBuildConfiguration section */ /* Begin XCConfigurationList section */ @@ -542,15 +442,6 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - C05A2C8F1F1DCE0900D4798B /* Build configuration list for PBXNativeTarget "tvmrpcLauncher" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - C05A2C8D1F1DCE0900D4798B /* Debug */, - C05A2C8E1F1DCE0900D4798B /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; /* End XCConfigurationList section */ }; rootObject = 
C02637431F1C25E8007247A9 /* Project object */; diff --git a/apps/ios_rpc/tvmrpc.xcodeproj/xcshareddata/xcschemes/tvmrpc.xcscheme b/apps/ios_rpc/tvmrpc.xcodeproj/xcshareddata/xcschemes/tvmrpc.xcscheme new file mode 100644 index 0000000000000..42a53a2336a7f --- /dev/null +++ b/apps/ios_rpc/tvmrpc.xcodeproj/xcshareddata/xcschemes/tvmrpc.xcscheme @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/apps/ios_rpc/tvmrpc/Base.lproj/Main.storyboard b/apps/ios_rpc/tvmrpc/Base.lproj/Main.storyboard index 2356abf6ca4f1..3ae0c28b43899 100644 --- a/apps/ios_rpc/tvmrpc/Base.lproj/Main.storyboard +++ b/apps/ios_rpc/tvmrpc/Base.lproj/Main.storyboard @@ -16,13 +16,11 @@ - - - - + + - + @@ -35,33 +33,21 @@ - + - - - + + diff --git a/apps/ios_rpc/tvmrpc/RPCArgs.h b/apps/ios_rpc/tvmrpc/RPCArgs.h new file mode 100644 index 0000000000000..e5a1ee47a019f --- /dev/null +++ b/apps/ios_rpc/tvmrpc/RPCArgs.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_APPS_IOS_RPC_ARGS_H_ +#define TVM_APPS_IOS_RPC_ARGS_H_ + +#import "RPCServer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! 
+ * \brief Struct representing arguments of iOS RPC app + */ +typedef struct RPCArgs_t { + /// Tracker or Proxy address (actually ip) + const char* host_url; + + /// Tracker or Proxy port + int host_port; + + /// device key to report + const char* key; + + /// custom adress to report into Tracker. Ignored for other server modes. + const char* custom_addr; + + /// Verbose mode. Will print status messages to std out. + /// 0 - no prints , 1 - print state to output + bool verbose; + + /// Immediate server launch. No UI interaction. + /// 0 - UI interaction, 1 - automatically connect on launch + bool immediate_connect; + + /// Server mode + RPCServerMode server_mode; +} RPCArgs; + +/*! + * \brief Get current global RPC args + */ +RPCArgs get_current_rpc_args(void); + +/*! + * \brief Set current global RPC args and update values in app cache + */ +void set_current_rpc_args(RPCArgs args); + +/*! + * \brief Pars command line args and update current global RPC args + * Also updates values in app cache + */ +void update_rpc_args(int argc, char* argv[]); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_APPS_IOS_RPC_ARGS_H_ diff --git a/apps/ios_rpc/tvmrpc/RPCArgs.mm b/apps/ios_rpc/tvmrpc/RPCArgs.mm new file mode 100644 index 0000000000000..7f5d68d7dde55 --- /dev/null +++ b/apps/ios_rpc/tvmrpc/RPCArgs.mm @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "RPCArgs.h" + +#import + +#import "../../../src/support/socket.h" +#import "../../../src/support/utils.h" + +#import + +using std::string; + +const char* kUsage = + "\n" + "iOS tvmrpc application supported flags:\n" + "--host_url The tracker/proxy address, Default=0.0.0.0\n" + "--host_port The tracker/proxy port, Default=9190\n" + "--key The key used to identify the device type in tracker. Default=\"\"\n" + "--custom_addr Custom IP Address to Report to RPC Tracker. Default=\"\"\n" + "--immediate_connect No UI interconnection, connect to tracker immediately. Default=False\n" + "--verbose Allow to print status info to std out. Default=False\n" + "--server_mode Server mode. Can be \"standalone\", \"proxy\" or \"tracker\". 
" + "Default=standalone \n" + "\n"; + +struct RPCArgs_cpp { + string host_url = "0.0.0.0"; + int host_port = 9190; + + string key; + string custom_addr = "null"; + + bool immediate_connect = false; + bool verbose = false; + RPCServerMode server_mode = RPCServerMode_Tracker; + + operator RPCArgs() const { + return RPCArgs{.host_url = host_url.c_str(), + .host_port = host_port, + .key = key.c_str(), + .custom_addr = custom_addr.c_str(), + .verbose = verbose, + .immediate_connect = immediate_connect, + .server_mode = server_mode}; + }; + + RPCArgs_cpp& operator=(const RPCArgs& args) { + host_url = args.host_url; + host_port = args.host_port; + key = args.key; + custom_addr = args.custom_addr; + verbose = args.verbose; + immediate_connect = args.immediate_connect; + server_mode = args.server_mode; + return *this; + } +}; + +struct RPCArgs_cpp g_rpc_args; + +static void restore_from_cache() { + NSUserDefaults* defaults = [NSUserDefaults standardUserDefaults]; + + auto get_string_from_cache = [defaults](const char* key) { + NSString* ns_key = [NSString stringWithUTF8String:key]; + NSString* ns_val = [defaults stringForKey:ns_key]; + return std::string(ns_val != nil ? 
[ns_val UTF8String] : ""); + }; + + auto get_int_from_cache = [defaults](const char* key) { + NSString* ns_key = [NSString stringWithUTF8String:key]; + return static_cast([defaults integerForKey:ns_key]); + }; + + g_rpc_args.host_url = get_string_from_cache("RPCArgs_url"); + g_rpc_args.host_port = get_int_from_cache("RPCArgs_port"); + g_rpc_args.key = get_string_from_cache("RPCArgs_key"); +} + +static void update_in_cache() { + NSUserDefaults* defaults = [NSUserDefaults standardUserDefaults]; + + [defaults setObject:[NSString stringWithUTF8String:g_rpc_args.host_url.c_str()] + forKey:@"RPCArgs_url"]; + [defaults setInteger:g_rpc_args.host_port forKey:@"RPCArgs_port"]; + [defaults setObject:[NSString stringWithUTF8String:g_rpc_args.key.c_str()] forKey:@"RPCArgs_key"]; +} + +string GetCmdOption(int argc, char* argv[], string option, bool key = false) { + string cmd; + for (int i = 1; i < argc; ++i) { + string arg = argv[i]; + if (arg.find(option) == 0) { + if (key) { + cmd = argv[i]; + return cmd; + } + // We assume "=" is the end of option. 
+ ICHECK_EQ(*option.rbegin(), '='); + cmd = arg.substr(arg.find('=') + 1); + return cmd; + } + } + return cmd; +} + +void update_rpc_args(int argc, char* argv[]) { + restore_from_cache(); + RPCArgs_cpp& args = g_rpc_args; + + using tvm::support::IsNumber; + using tvm::support::ValidateIP; + constexpr int MAX_PORT_NUM = 65535; + + const string immediate_connect = GetCmdOption(argc, argv, "--immediate_connect", true); + args.immediate_connect = !immediate_connect.empty(); + + const string verbose = GetCmdOption(argc, argv, "--verbose", true); + args.verbose = !verbose.empty(); + + const string server_mode = GetCmdOption(argc, argv, "--server_mode=", false); + if (!server_mode.empty()) { + if (server_mode == "tracker") { + args.server_mode = RPCServerMode_Tracker; + } else if (server_mode == "proxy") { + args.server_mode = RPCServerMode_Proxy; + } else if (server_mode == "standalone") { + args.server_mode = RPCServerMode_Standalone; + } else { + LOG(WARNING) << "Wrong server_mode value."; + LOG(INFO) << kUsage; + exit(1); + } + } + + const string host_url = GetCmdOption(argc, argv, "--host_url="); + if (!host_url.empty()) { + if (!ValidateIP(host_url)) { + LOG(WARNING) << "Wrong tracker address format."; + LOG(INFO) << kUsage; + exit(1); + } + args.host_url = host_url; + } + + const string host_port = GetCmdOption(argc, argv, "--host_port="); + if (!host_port.empty()) { + if (!IsNumber(host_port) || stoi(host_port) > MAX_PORT_NUM) { + LOG(WARNING) << "Wrong trackerport number."; + LOG(INFO) << kUsage; + exit(1); + } + args.host_port = stoi(host_port); + } + + const string key = GetCmdOption(argc, argv, "--key="); + if (!key.empty()) { + args.key = key; + } + + const string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + if (!custom_addr.empty()) { + if (!ValidateIP(custom_addr)) { + LOG(WARNING) << "Wrong custom address format."; + LOG(INFO) << kUsage; + exit(1); + } + args.custom_addr = '"' + custom_addr + '"'; + } + + update_in_cache(); +} + +RPCArgs 
get_current_rpc_args(void) { return g_rpc_args; } + +void set_current_rpc_args(RPCArgs args) { + g_rpc_args = args; + update_in_cache(); +} diff --git a/apps/ios_rpc/tvmrpc/RPCServer.h b/apps/ios_rpc/tvmrpc/RPCServer.h new file mode 100644 index 0000000000000..dcb0d348df9e0 --- /dev/null +++ b/apps/ios_rpc/tvmrpc/RPCServer.h @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file Provide interfaces to launch and control RPC Service routine + */ + +#import + +/*! + * \brief Enum with possible status of RPC server + * Used to report state to listener + */ +typedef enum { + RPCServerStatus_Launched, // Worker thread is launched + RPCServerStatus_Stopped, // Worker thread stopped + RPCServerStatus_Connected, // Connected to Proxy/Tracker + RPCServerStatus_Disconnected, // Disconnected from Proxy/Tracker + RPCServerStatus_RPCSessionStarted, // RPC session is started + RPCServerStatus_RPCSessionFinished // RPC session is finished +} RPCServerStatus; + +/*! + * \brief Enum with modes of servicing supported by RPCServer + */ +typedef enum { + /// Tracker mode. Same as Standalone Server plus register it into Tracker. + RPCServerMode_Tracker, + /// Proxy mode. 
Connect to proxy server and wait response. + RPCServerMode_Proxy, + /// Standalone RPC server mode. Open port with RPC server and wait incoming connection. + RPCServerMode_Standalone +} RPCServerMode; + +/*! + * \brief Listener for events happened with RPCServer + */ +@protocol RPCServerEventListener +/// Callback to notifying about new status +- (void)onError:(NSString*)msg; +/// Callback to notifying about error +- (void)onStatusChanged:(RPCServerStatus)status; +@end + +/*! + * \brief RPC Server instance + * Contains internal worker thread plus + */ +@interface RPCServer : NSObject + +/// Event listener delegate to set +@property(retain) id delegate; +/// Device key to report during RPC session +@property(retain) NSString* key; +/// Host address of Proxy/Tracker server (generally IPv4). Ignored for Standalone mode. +@property(retain) NSString* host; +/// Port of Proxy/Tracker server. Ignored for Standalone mode. +@property int port; +/// Custom address to report into tracker server (optional). Ignored for Standalone/Proxy modes +@property(retain) NSString* custom_addr; +/// Triger to enable printing of server state info +@property BOOL verbose; +/// RPC port opened on the device. Ignored for Proxy/Tracker modes +@property int actual_port; +/// IP address of the device. Ignored for Proxy/Tracker modes +@property(retain) NSString* device_addr; + +/*! + * \brief Create server with specified sevicing mode + * \param mode Mode of server + */ ++ (instancetype)serverWithMode:(RPCServerMode)mode; + +/*! + * \brief Start RPC server with options. Non blocking method + */ +- (void)start; + +/*! + * \brief Stop RPC server. 
Non blocking method + */ +- (void)stop; + +@end diff --git a/apps/ios_rpc/tvmrpc/RPCServer.mm b/apps/ios_rpc/tvmrpc/RPCServer.mm new file mode 100644 index 0000000000000..b381d1afd1fff --- /dev/null +++ b/apps/ios_rpc/tvmrpc/RPCServer.mm @@ -0,0 +1,818 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ViewController.mm + */ + +#import "RPCServer.h" + +#include +#include + +#include +#include + +// To get device WiFi IP +#include +#include +#include +#include + +// TVM internal header to access Magic keys like kRPCMagic and others +#include "../../../src/runtime/rpc/rpc_endpoint.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Message handling function for event driven server. + * + * \param in_bytes The incoming bytes. + * \param event_flag 1: read_available, 2: write_avaiable. + * \return State flag. + * 1: continue running, no need to write, + * 2: need to write + * 0: shutdown + */ +using FEventHandler = PackedFunc; + +/*! + * \brief Create a server event handler. + * + * \param outputStream The output stream used to send outputs. + * \param name The name of the server. + * \param remote_key The remote key + * \return The event handler. 
+ */ +FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, + std::string remote_key) { + const PackedFunc* event_handler_factory = Registry::Get("rpc.CreateEventDrivenServer"); + ICHECK(event_handler_factory != nullptr) + << "You are using tvm_runtime module built without RPC support. " + << "Please rebuild it with USE_RPC flag."; + + PackedFunc writer_func([outputStream](TVMArgs args, TVMRetValue* rv) { + TVMByteArray* data = args[0].ptr(); + int64_t nbytes = [outputStream write:reinterpret_cast(data->data) + maxLength:data->size]; + if (nbytes < 0) { + NSLog(@"%@", [outputStream streamError].localizedDescription); + throw tvm::Error("Stream error"); + } + *rv = nbytes; + }); + + return (*event_handler_factory)(writer_func, name, remote_key); +} + +/*! + * \brief Helper function to query real IP of device in WiFi network + * \return string with IPv4 in format "192.168.0.1" or "unknown" if cannot detect + */ +static std::string getWiFiAddress() { + std::string address = "unknown"; + ifaddrs* interfaces = nullptr; + + int success = getifaddrs(&interfaces); + if (success == 0) { + ifaddrs* temp_addr = interfaces; + while (temp_addr != NULL) { + if (temp_addr->ifa_addr->sa_family == AF_INET) { + // Check if interface is en0 which is the wifi connection on the iPhone + if (std::string(temp_addr->ifa_name) == "en0") { + address = inet_ntoa(((sockaddr_in*)temp_addr->ifa_addr)->sin_addr); + } + } + temp_addr = temp_addr->ifa_next; + } + } + + freeifaddrs(interfaces); + return address; +} + +} // namespace runtime +} // namespace tvm + +// Base class for any type of RPC servicing +@interface RPCServerBase : RPCServer + +/*! 
+ * Methods should be implemented in inherited classes + */ +- (bool)onReadHandler; // return true - continue feeding, false - stop, try to drain output buffer +- (bool)onWriteHandler; // return true - continue draining, false - no data to write +- (void)onEndEncountered; // called on disconnect or session desided that it's shutdown time +- (void)open; // Initiate listening objects like i/o streams and other resources +- (void)close; // Deinitialize resources opend in "open" method +@end + +@implementation RPCServerBase { + // Worker thread + NSThread* worker_thread_; + // Triger to continue RunLoop processing inside worker_thread_ + BOOL shouldKeepRunning; + // Input socket stream + @protected + NSInputStream* inputStream_; + // Output socket stream + NSOutputStream* outputStream_; + // Temporal buffer with data to send + std::string sendBuffer_; + // Temporal receive buffer + std::string recvBuffer_; + // Requested data size to accumulate in recvBuffer_ before continue processing + int requiredToRecv_; +} + +/*! + * Start internal worker thread with RunLoop and submit correspoding open handlers into it + * Not blocking + */ +- (void)start { + worker_thread_ = [[NSThread alloc] initWithBlock:^{ + @autoreleasepool { + [self notifyState:RPCServerStatus_Launched]; + [self open]; + shouldKeepRunning = YES; + while (shouldKeepRunning && [[NSRunLoop currentRunLoop] runMode:NSDefaultRunLoopMode + beforeDate:[NSDate distantFuture]]) + ; + [self notifyState:RPCServerStatus_Stopped]; + } + }]; + [worker_thread_ start]; +} + +/*! + * Send message to workel thread runloop to finish processing + * Not blocking + */ +- (void)stop { + if (worker_thread_ == nil) return; + + [self performSelector:@selector(stop_) onThread:worker_thread_ withObject:nil waitUntilDone:NO]; + worker_thread_ = nil; // TODO: is it valide? may be better to do that inside NSThread? +} + +- (void)stop_ { + [self close]; + shouldKeepRunning = NO; +} + +/*! 
+ * Base implementation to set up i/o streams + * Will connect to host and port specified in corresponding properties + */ +- (void)open { + CFReadStreamRef readStream; + CFWriteStreamRef writeStream; + CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.host, self.port, &readStream, + &writeStream); + inputStream_ = (__bridge NSInputStream*)readStream; + outputStream_ = (__bridge NSOutputStream*)writeStream; + [inputStream_ setDelegate:self]; + [outputStream_ setDelegate:self]; + [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [outputStream_ open]; + [inputStream_ open]; +} + +/*! + * Base implementation to set up i/o streams + * Will assign i/o streams to provided socket connection. + */ +- (void)openWithSocket:(CFSocketNativeHandle)sock { + CFReadStreamRef readStream; + CFWriteStreamRef writeStream; + CFStreamCreatePairWithSocket(NULL, sock, &readStream, &writeStream); + inputStream_ = (__bridge NSInputStream*)readStream; + outputStream_ = (__bridge NSOutputStream*)writeStream; + [inputStream_ setDelegate:self]; + [outputStream_ setDelegate:self]; + [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [outputStream_ open]; + [inputStream_ open]; +} + +/*! 
+ * Close i/o streams associated with connection + */ +- (void)close { + [inputStream_ close]; + [outputStream_ close]; + [inputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [outputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; + [inputStream_ setDelegate:nil]; + [outputStream_ setDelegate:nil]; + inputStream_ = nil; + outputStream_ = nil; +} + +/// Unimplemented stubs +- (bool)onReadHandler { + return false; +} +- (bool)onWriteHandler { + return false; +} +- (void)onEndEncountered { +} + +/*! + * Try to read data from stream and call processing handler + */ +- (void)tryToRead { + const int kBufferSize = 4 << 10; // 4kB buffer + const int prev_size = recvBuffer_.size(); + recvBuffer_.resize(kBufferSize); + size_t nbytes = [inputStream_ read:(uint8_t*)recvBuffer_.data() + prev_size + maxLength:recvBuffer_.size() - prev_size]; + recvBuffer_.resize(nbytes + prev_size); + + // feed while it accepts or a particular buffer size is requested + while (!recvBuffer_.empty() && requiredToRecv_ <= recvBuffer_.size() && [self onReadHandler]) + ; +} + +/*! + * Try to write remaining data to stream and call processing handler + */ +- (void)tryToWrite { + if (!sendBuffer_.empty()) { + size_t nbytes = [outputStream_ write:(uint8_t*)sendBuffer_.data() maxLength:sendBuffer_.size()]; + sendBuffer_.erase(0, nbytes); + } + // call write handler while it wants to be called and space is available + while (sendBuffer_.empty() && [outputStream_ hasSpaceAvailable] && [self onWriteHandler]) + ; +} + +/*! 
+ * Main event handler of socket stream events + */ +- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event { + std::string buffer; + switch (event) { + case NSStreamEventOpenCompleted: { + // Nothing + break; + } + case NSStreamEventHasBytesAvailable: + if (strm == inputStream_) { + [self tryToRead]; + if ([outputStream_ hasSpaceAvailable]) [self tryToWrite]; + } + break; + case NSStreamEventHasSpaceAvailable: { + if (strm == outputStream_) { + [self tryToWrite]; + if ([inputStream_ hasBytesAvailable]) [self tryToRead]; + } + break; + } + case NSStreamEventErrorOccurred: { + [self notifyError:[strm streamError].localizedDescription]; + break; + } + case NSStreamEventEndEncountered: { + [self onEndEncountered]; + break; + } + default: { + NSLog(@"Unknown event"); + } + } +} + +#pragma mark - Helpers + +/*! + * Set buffer to send into stream. Try to send immediatly or submit to lazy sending + * Non blocking operation + */ +- (void)toSend:(NSData*)data { + sendBuffer_.append(static_cast(data.bytes), data.length); + + // try to flush buffer + NSInteger sent_size = [outputStream_ write:(uint8_t*)sendBuffer_.data() + maxLength:sendBuffer_.size()]; + sendBuffer_.erase(0, sent_size); +} + +/*! + * Set buffer to send in packet format [size, data]. Behaviour is same as for toSend. + */ +- (void)toSendPacked:(NSData*)data { + uint32_t packet_size = data.length; + [self toSend:[NSData dataWithBytes:&packet_size length:sizeof(packet_size)]]; + [self toSend:data]; +} + +/*! + */ +- (NSData*)requestInputDataWithSize:(NSInteger)size { + if (recvBuffer_.size() < size) { + requiredToRecv_ = size; + return nil; + } + NSData* res = [NSData dataWithBytes:recvBuffer_.data() length:size]; + recvBuffer_.erase(0, size); + return res; +} + +/*! 
+ */ +- (NSData*)requestInputDataPacked { + uint32_t size; + if (recvBuffer_.size() < sizeof(size)) { + requiredToRecv_ = sizeof(size); + return nil; + } + size = *(uint32_t*)recvBuffer_.data(); + if (recvBuffer_.size() < sizeof(size) + size) { + requiredToRecv_ = sizeof(size) + size; + return nil; + } + NSData* res = [NSData dataWithBytes:recvBuffer_.data() + sizeof(size) length:size]; + recvBuffer_.erase(0, sizeof(size) + size); + return res; +}; + +#pragma mark - Notifiers + +/*! + * Notify external listener about error. + * Also print error message to std out in case of Verbose mode + */ +- (void)notifyError:(NSString*)msg { + // Duplicate error message in std output. Host launcher script may listen it. + if (self.verbose) NSLog(@"[IOS-RPC] ERROR: %@", msg); + if (self.delegate) [self.delegate onError:msg]; +} + +/*! + * Notify external listener about server state changes. + * Also print information to std out in case of Verbose mode + */ +- (void)notifyState:(RPCServerStatus)state { + // Duplicate sattus changing in std output. Host launcher script may listen it. + if (self.verbose) NSLog(@"[IOS-RPC] STATE: %d", state); + if (self.delegate != nil) [self.delegate onStatusChanged:state]; +} + +@end + +@interface RPCServerProxy : RPCServerBase +@end + +typedef enum { + RPCServerProxyState_Idle, + RPCServerProxyState_HandshakeToSend, + RPCServerProxyState_HandshakeToRecv, + RPCServerProxyState_Processing, +} RPCServerProxyState; + +@implementation RPCServerProxy { + /// Original TVM RPC event handler + tvm::runtime::FEventHandler handler_; + @protected + /// Sate of Proxy client implementation + RPCServerProxyState state_; +} + +- (instancetype)init { + if (self = [super init]) { + handler_ = nullptr; + state_ = RPCServerProxyState_Idle; + } + return self; +} + +/*! 
+ * Implement matching of internat state on state available for outside users + */ +- (void)setState:(RPCServerProxyState)new_state { + // Send Connected notification because Proxy doesn't responce until client connected. + if (new_state == RPCServerProxyState_HandshakeToRecv) + [self notifyState:RPCServerStatus_Connected]; + if (new_state == RPCServerProxyState_Idle) [self notifyState:RPCServerStatus_Disconnected]; + if (state_ == RPCServerProxyState_HandshakeToRecv && new_state == RPCServerProxyState_Processing) + [self notifyState:RPCServerStatus_RPCSessionStarted]; + if (state_ == RPCServerProxyState_Processing && new_state == RPCServerProxyState_Idle) + [self notifyState:RPCServerStatus_RPCSessionStarted]; + + state_ = new_state; +} + +- (bool)onWriteHandler { + switch (state_) { + case RPCServerProxyState_HandshakeToSend: { + // Send together kRPCMagic and server descriptor because of Proxy + int32_t code = tvm::runtime::kRPCMagic; + [self toSend:[NSData dataWithBytes:&code length:sizeof(code)]]; + + std::string full_key = std::string("server:") + self.key.UTF8String; + [self toSendPacked:[NSData dataWithBytes:full_key.data() length:full_key.size()]]; + + self.state = RPCServerProxyState_HandshakeToRecv; + return TRUE; + } + case RPCServerProxyState_Processing: { + try { + TVMByteArray dummy{nullptr, 0}; + int flag = handler_(dummy, 2); + if (flag == 0) { + [self onEndEncountered]; + } + return flag == 2; + } catch (const tvm::Error& e) { + [self close]; + } + break; + } + default: + // Nothing + break; + } + return FALSE; +} + +- (bool)onReadHandler { + switch (state_) { + case RPCServerProxyState_HandshakeToRecv: { + int32_t code = tvm::runtime::kRPCMagic; + NSData* data = [self requestInputDataWithSize:sizeof(code)]; + if (data == nil) return FALSE; + + if (*(int32_t*)data.bytes != tvm::runtime::kRPCMagic) { + [self notifyError:@"Wrong responce, is not RPC client."]; + [self close]; + return FALSE; + break; + } + + handler_ = 
tvm::runtime::CreateServerEventHandler(outputStream_, "iphone", "%toinit"); + + self.state = RPCServerProxyState_Processing; + return TRUE; + break; + } + case RPCServerProxyState_Processing: { + int flag = 1; + if ([outputStream_ hasSpaceAvailable]) { + flag |= 2; + } + // always try to write + try { + TVMByteArray arr{recvBuffer_.data(), recvBuffer_.size()}; + flag = handler_(arr, flag); + recvBuffer_.clear(); + if (flag == 0) { + [self onEndEncountered]; + } + return flag == 1; + } catch (const tvm::Error& e) { + [self close]; + } + break; + } + default: + // Nothing + break; + } + return FALSE; +} + +- (void)onEndEncountered { + // Automatic reconnection when session is finished. + [self close]; + [self open]; +} + +- (void)open { + [super open]; + self.state = RPCServerProxyState_HandshakeToSend; +} + +- (void)close { + [super close]; + handler_ = nullptr; + self.state = RPCServerProxyState_Idle; +} + +@end + +@interface RPCServerStandalone : RPCServerProxy +@property(readonly) int rpc_port; +@end + +@implementation RPCServerStandalone { + // Socket to listen incoming connections + CFSocketRef socket_; + /// Current socket connection handler + CFSocketNativeHandle connection_; + /// Port range to try bind to socket + int port_range_start; + int port_range_end; +} + +- (instancetype)init { + if (self = [super init]) { + connection_ = 0; + port_range_start = 9090; + port_range_end = 9099; + } + return self; +} + +- (void)setState:(RPCServerProxyState)new_state { + if (state_ == RPCServerProxyState_Idle && new_state == RPCServerProxyState_HandshakeToSend) { + self.actual_port = _rpc_port; + self.device_addr = [NSString stringWithUTF8String:tvm::runtime::getWiFiAddress().c_str()]; + if (self.verbose) { + // Notify host runner script with actual address + NSLog(@"[IOS-RPC] IP: %s", tvm::runtime::getWiFiAddress().c_str()); + NSLog(@"[IOS-RPC] PORT: %d", _rpc_port); + } + [self notifyState:RPCServerStatus_Connected]; + } + if (new_state == RPCServerProxyState_Idle) 
[self notifyState:RPCServerStatus_Disconnected]; + if (state_ == RPCServerProxyState_HandshakeToRecv && new_state == RPCServerProxyState_Processing) + [self notifyState:RPCServerStatus_RPCSessionStarted]; + if (state_ == RPCServerProxyState_Processing && new_state == RPCServerProxyState_HandshakeToSend) + [self notifyState:RPCServerStatus_RPCSessionFinished]; + + state_ = new_state; +} + +- (void)handleConnect:(CFSocketNativeHandle)hdl { + connection_ = hdl; + [super openWithSocket:connection_]; + self.state = RPCServerProxyState_HandshakeToSend; +} + +static void handleConnect(CFSocketRef socket, CFSocketCallBackType type, CFDataRef address, + const void* data, void* info) { + RPCServerStandalone* it = (__bridge RPCServerStandalone*)(info); + [it handleConnect:*static_cast(data)]; +} + +- (void)open { + CFSocketContext ctx{}; + ctx.info = (__bridge void*)self; + + socket_ = CFSocketCreate(kCFAllocatorDefault, PF_INET, SOCK_STREAM, IPPROTO_TCP, + kCFSocketAcceptCallBack, handleConnect, &ctx); + self->_rpc_port = 0; + + // Try to bind with range + for (int port = port_range_start; port < port_range_end; port++) { + struct sockaddr_in sin; + memset(&sin, 0, sizeof(sin)); + sin.sin_len = sizeof(sin); + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = INADDR_ANY; + + CFDataRef sincfd = CFDataCreate(kCFAllocatorDefault, (UInt8*)&sin, sizeof(sin)); + CFSocketError res = CFSocketSetAddress(socket_, sincfd); + CFRelease(sincfd); + if (res == kCFSocketSuccess) { + self->_rpc_port = port; + break; + } + } + if (self->_rpc_port == 0) { + @throw + [NSException exceptionWithName:@"SocketError" + reason:[NSString stringWithFormat:@"Unable bind socket to port" + "in range [%d, %d]", + port_range_start, port_range_end] + userInfo:nil]; + } + + CFRunLoopSourceRef socketsource = CFSocketCreateRunLoopSource(kCFAllocatorDefault, socket_, 0); + CFRunLoopAddSource(CFRunLoopGetCurrent(), socketsource, kCFRunLoopDefaultMode); + + self.state = 
RPCServerProxyState_HandshakeToSend; +} + +- (void)closeSocket { + CFSocketInvalidate(socket_); +} + +- (void)close { + [super close]; + close(connection_); +} + +- (void)onEndEncountered { + [self close]; + [self notifyState:RPCServerStatus_RPCSessionFinished]; +} + +@end + +@interface RPCServerTracker : RPCServerBase +@end + +typedef enum { + RPCServerTracker_Idle, + RPCServerTracker_HandshakeToSend, + RPCServerTracker_HandshakeToRecv, + RPCServerTracker_ServerInfoToSend, + RPCServerTracker_ServerInfoToRecv, + RPCServerTracker_ReportResToSend, + RPCServerTracker_ReportResToRecv, + RPCServerTracker_UpdateKeyToSend, + RPCServerTracker_UpdateKeyToRecv, + RPCServerTracker_WaitConnection +} RPCServerTrackerState; + +@implementation RPCServerTracker { + RPCServerTrackerState state_; + RPCServerStandalone* rpc_server_; +} + +- (void)setState:(RPCServerTrackerState)new_state { + if (state_ == RPCServerTracker_ReportResToRecv && new_state == RPCServerTracker_WaitConnection) + [self notifyState:RPCServerStatus_Connected]; + if (state_ == RPCServerTracker_WaitConnection && new_state == RPCServerTracker_Idle) + [self notifyState:RPCServerStatus_Disconnected]; + + state_ = new_state; +} + +- (bool)onWriteHandler { + switch (state_) { + case RPCServerTracker_HandshakeToSend: { + int32_t code = tvm::runtime::kRPCTrackerMagic; + [self toSend:[NSData dataWithBytes:&code length:sizeof(code)]]; + self.state = RPCServerTracker_HandshakeToRecv; + return TRUE; + break; + } + case RPCServerTracker_ServerInfoToSend: { + std::ostringstream ss; + ss << "[" << static_cast(tvm::runtime::TrackerCode::kUpdateInfo) + << ", {\"key\": \"server:" << self.key.UTF8String << "\", \"addr\": [" + << self.custom_addr.UTF8String << ", \"" << self.port << "\"]}]"; + std::string data_s = ss.str(); + [self toSendPacked:[NSData dataWithBytes:data_s.data() length:data_s.length()]]; + self.state = RPCServerTracker_ServerInfoToRecv; + return TRUE; + break; + } + case RPCServerTracker_ReportResToSend: { + 
std::mt19937 gen(std::random_device{}()); + std::uniform_real_distribution dis(0.0, 1.0); + + std::string address_to_report = "null"; + if (self.custom_addr != nil && self.custom_addr.length != 0) { + address_to_report = self.custom_addr.UTF8String; + } + + std::string matchkey = std::string(self.key.UTF8String) + ":" + std::to_string(dis(gen)); + std::ostringstream ss; + ss << "[" << static_cast(tvm::runtime::TrackerCode::kPut) << ", \"" + << self.key.UTF8String << "\", [" << rpc_server_.rpc_port << ", \"" << matchkey << "\"], " + << address_to_report << "]"; + + std::string data_s = ss.str(); + [self toSendPacked:[NSData dataWithBytes:data_s.data() length:data_s.length()]]; + self.state = RPCServerTracker_ReportResToRecv; + return TRUE; + break; + } + default: + // Nothing + break; + } + return FALSE; +} + +- (bool)onReadHandler { + static const std::string resp_OK = + std::to_string(static_cast(tvm::runtime::TrackerCode::kSuccess)); + switch (state_) { + case RPCServerTracker_HandshakeToRecv: { + NSData* data = [self requestInputDataWithSize:sizeof(int)]; + if (data == nil) return FALSE; + + if (*(int*)data.bytes != tvm::runtime::kRPCTrackerMagic) { + [self notifyError:@"Wrong responce, is not RPC Tracker."]; + [self close]; + return FALSE; + break; + } + self.state = RPCServerTracker_ServerInfoToSend; + return TRUE; + break; + } + case RPCServerTracker_ServerInfoToRecv: { + NSData* data = [self requestInputDataPacked]; + if (data == nil) return FALSE; + + if (std::string((char*)data.bytes, data.length) != resp_OK) { + [self notifyError:@"Failed to Update info on tracker. Responce is not OK."]; + [self close]; + return FALSE; + break; + } + self.state = RPCServerTracker_ReportResToSend; + return TRUE; + break; + } + case RPCServerTracker_ReportResToRecv: { + NSData* data = [self requestInputDataPacked]; + if (data == nil) return FALSE; + + if (std::string((char*)data.bytes, data.length) != resp_OK) { + [self notifyError:@"Failed to Put server into tracker. 
Responce is not OK."]; + [self close]; + return FALSE; + break; + } + self.state = RPCServerTracker_WaitConnection; + return TRUE; + break; + } + default: + // Nothing + break; + } + return FALSE; +} + +- (void)onEndEncountered { + [self close]; +} + +- (void)close { + [rpc_server_ close]; + [rpc_server_ closeSocket]; + [super close]; + self.state = RPCServerTracker_Idle; +} + +- (void)open { + // Start internal Standalone RPC server at first + rpc_server_ = [[RPCServerStandalone alloc] init]; + rpc_server_.key = self.key; + rpc_server_.delegate = self; + [rpc_server_ open]; + + [super open]; + self.state = RPCServerTracker_HandshakeToSend; +} + +- (void)onError:(NSString*)msg { + // transfer error form internal rpc_server_ to real delegate + [self notifyError:msg]; +} + +- (void)onStatusChanged:(RPCServerStatus)status { + if (status == RPCServerStatus_RPCSessionFinished) { + [self notifyState:status]; + self.state = RPCServerTracker_ReportResToSend; + [self tryToWrite]; + } +} +@end + +@implementation RPCServer + ++ (instancetype)serverWithMode:(RPCServerMode)mode { + if (mode == RPCServerMode_Standalone) return [[RPCServerStandalone alloc] init]; + if (mode == RPCServerMode_Proxy) return [[RPCServerProxy alloc] init]; + if (mode == RPCServerMode_Tracker) return [[RPCServerTracker alloc] init]; + return nil; +} + +/// Unimplemented stubs +- (void)start { +} +- (void)stop { +} + +@end diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.h b/apps/ios_rpc/tvmrpc/TVMRuntime.h deleted file mode 100644 index 0d172fc3eaa11..0000000000000 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file TVMRuntime.h - */ -#import -// Customize logging mechanism, redirect to NSLOG -#define TVM_LOG_CUSTOMIZE 1 -#define TVM_METAL_RUNTIME 1 - -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief Message handling function for event driven server. - * - * \param in_bytes The incoming bytes. - * \param event_flag 1: read_available, 2: write_avaiable. - * \return State flag. - * 1: continue running, no need to write, - * 2: need to write - * 0: shutdown - */ -using FEventHandler = std::function; - -/*! - * \brief Create a server event handler. - * - * \param outputStream The output stream used to send outputs. - * \param name The name of the server. - * \param remote_key The remote key - * \return The event handler. - */ -FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, - std::string remote_key); - -} // namespace runtime -} // namespace tvm - -@interface TVMRuntime : NSObject - -+ (void)launchSyncServer; - -@end diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index 8950eb4eab1d8..09a1a17ffd379 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -20,50 +20,26 @@ /*! 
* \file TVMRuntime.mm */ -#include "TVMRuntime.h" -// Runtime API -#include "../../../src/runtime/c_runtime_api.cc" -#include "../../../src/runtime/contrib/random/random.cc" -#include "../../../src/runtime/cpu_device_api.cc" -#include "../../../src/runtime/dso_library.cc" -#include "../../../src/runtime/file_utils.cc" -#include "../../../src/runtime/library_module.cc" -#include "../../../src/runtime/logging.cc" -#include "../../../src/runtime/metadata_module.cc" -#include "../../../src/runtime/module.cc" -#include "../../../src/runtime/ndarray.cc" -#include "../../../src/runtime/object.cc" -#include "../../../src/runtime/profiling.cc" -#include "../../../src/runtime/registry.cc" -#include "../../../src/runtime/source_utils.cc" -#include "../../../src/runtime/system_library.cc" -#include "../../../src/runtime/thread_pool.cc" -#include "../../../src/runtime/threading_backend.cc" -#include "../../../src/runtime/workspace_pool.cc" - -// RPC server -#include "../../../src/runtime/rpc/rpc_channel.cc" -#include "../../../src/runtime/rpc/rpc_endpoint.cc" -#include "../../../src/runtime/rpc/rpc_local_session.cc" -#include "../../../src/runtime/rpc/rpc_module.cc" -#include "../../../src/runtime/rpc/rpc_server_env.cc" -#include "../../../src/runtime/rpc/rpc_session.cc" -#include "../../../src/runtime/rpc/rpc_socket_impl.cc" -// Graph executor -#include "../../../src/runtime/graph_executor/graph_executor.cc" -// Metal -#include "../../../src/runtime/metal/metal_device_api.mm" -#include "../../../src/runtime/metal/metal_module.mm" -// CoreML -#include "../../../src/runtime/contrib/coreml/coreml_runtime.mm" + +#import + +#include + +#include "RPCArgs.h" + +// internal TVM header +#include <../../../src/runtime/file_utils.h> #if defined(USE_CUSTOM_DSO_LOADER) && USE_CUSTOM_DSO_LOADER == 1 +// internal TVM header to achive Library class +#include <../../../src/runtime/library_module.h> #include #endif namespace tvm { namespace runtime { namespace detail { + // Override logging 
mechanism void LogFatalImpl(const std::string& file, int lineno, const std::string& message) { throw tvm::runtime::InternalError(file, lineno, message); @@ -72,77 +48,13 @@ void LogFatalImpl(const std::string& file, int lineno, const std::string& messag void LogMessageImpl(const std::string& file, int lineno, const std::string& message) { NSLog(@"%s:%d: %s", file.c_str(), lineno, message.c_str()); } -} -} -} // namespace dmlc - -namespace tvm { -namespace runtime { - -class NSStreamChannel final : public RPCChannel { - public: - explicit NSStreamChannel(NSOutputStream* stream) : stream_(stream) {} - - size_t Send(const void* data, size_t size) final { - ssize_t nbytes = [stream_ write:reinterpret_cast(data) maxLength:size]; - if (nbytes < 0) { - NSLog(@"%@", [stream_ streamError].localizedDescription); - throw tvm::Error("Stream error"); - } - return nbytes; - } - - size_t Recv(void* data, size_t size) final { - LOG(FATAL) << "Do not allow explicit receive for"; - return 0; - } - - private: - NSOutputStream* stream_; -}; - -FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, - std::string remote_key) { - std::unique_ptr ch(new NSStreamChannel(outputStream)); - std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); - return [sess](const std::string& in_bytes, int flag) { - return sess->ServerAsyncIOEventHandler(in_bytes, flag); - }; -} -// Runtime environment -struct RPCEnv { - public: - RPCEnv() { - NSString* path = NSTemporaryDirectory(); - base_ = [path UTF8String]; - if (base_[base_.length() - 1] != '/') { - base_ = base_ + '/'; - } - } - // Get Path. - std::string GetPath(const std::string& file_name) { return base_ + file_name; } - - private: - std::string base_; -}; - -void LaunchSyncServer() { - // only load dylib from frameworks. 
- NSBundle* bundle = [NSBundle mainBundle]; - NSString* base = [bundle privateFrameworksPath]; - NSString* path = [base stringByAppendingPathComponent:@"tvm/rpc_config.txt"]; - std::string name = [path UTF8String]; - std::ifstream fs(name, std::ios::in); - std::string url, key; - int port; - ICHECK(fs >> url >> port >> key) << "Invalid RPC config file " << name; - RPCConnect(url, port, "server:" + key, TVMArgs(nullptr, nullptr, 0))->ServerLoop(); -} +} // namespace detail TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); + static const std::string base_ = NSTemporaryDirectory().UTF8String; + const std::string path = args[0]; + *rv = base_ + "/" + path; }); TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { @@ -206,11 +118,3 @@ void Init(const std::string& name) { } // namespace runtime } // namespace tvm - -@implementation TVMRuntime - -+ (void)launchSyncServer { - tvm::runtime::LaunchSyncServer(); -} - -@end diff --git a/apps/ios_rpc/tvmrpc/ViewController.h b/apps/ios_rpc/tvmrpc/ViewController.h index b188a87b20d33..5e2f2bbafa99e 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.h +++ b/apps/ios_rpc/tvmrpc/ViewController.h @@ -22,28 +22,9 @@ */ #import -#include "TVMRuntime.h" +#import "RPCServer.h" -@interface ViewController : UIViewController { - // input socket stream - NSInputStream* inputStream_; - // output socket stream - NSOutputStream* outputStream_; - // temporal receive buffer. - std::string recvBuffer_; - // Whether connection is initialized. - bool initialized_; - // Whether auto reconnect when a session is done. - bool auto_reconnect_; - // The key of the server. - std::string key_; - // Initial bytes to be send to remote - std::string initBytes_; - // Send pointer of initial bytes. - size_t initSendPtr_; - // Event handler. 
- tvm::runtime::FEventHandler handler_; -} +@interface ViewController : UIViewController @property(weak, nonatomic) IBOutlet UITextField* proxyURL; @property(weak, nonatomic) IBOutlet UITextField* proxyPort; @@ -52,6 +33,7 @@ @property(weak, nonatomic) IBOutlet UITextView* infoText; - (IBAction)connect:(id)sender; -- (IBAction)disconnect:(id)sender; +@property(retain, nonatomic) IBOutlet UIButton* ConnectButton; +@property(retain, nonatomic) IBOutlet UISegmentedControl* ModeSelector; @end diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 4218a6a19d902..3f8c647fa4f29 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -22,168 +22,147 @@ */ #import "ViewController.h" -#include +#import "RPCArgs.h" -@implementation ViewController - -- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event { - std::string buffer; - switch (event) { - case NSStreamEventOpenCompleted: { - self.statusLabel.text = @"Connected"; - break; - } - case NSStreamEventHasBytesAvailable: - if (strm == inputStream_) { - [self onReadAvailable]; - } - break; - case NSStreamEventHasSpaceAvailable: { - if (strm == outputStream_) { - [self onWriteAvailable]; - } - break; - } - case NSStreamEventErrorOccurred: { - NSLog(@"%@", [strm streamError].localizedDescription); - break; - } - case NSStreamEventEndEncountered: { - [self close]; - // auto reconnect when normal end. - [self open]; - break; - } - default: { - NSLog(@"Unknown event"); - } - } +@implementation ViewController { + // server implementation + RPCServer* server_; + // Button state. 
True - push will start connection, false - push will disconnect + bool to_connect_; } -- (void)onReadAvailable { - constexpr int kRPCMagic = 0xff271; - if (!initialized_) { - int code; - size_t nbytes = [inputStream_ read:reinterpret_cast(&code) maxLength:sizeof(code)]; - if (nbytes != sizeof(code)) { - self.infoText.text = @"Fail to receive remote confirmation code."; - [self close]; - } else if (code == kRPCMagic + 2) { - self.infoText.text = @"Proxy server cannot find client that matches the key"; - [self close]; - } else if (code == kRPCMagic + 1) { - self.infoText.text = @"Proxy server already have another server with same key"; - [self close]; - } else if (code != kRPCMagic) { - self.infoText.text = @"Given address is not a TVM RPC Proxy"; - [self close]; - } else { - initialized_ = true; - self.statusLabel.text = @"Proxy connected."; - ICHECK(handler_ != nullptr); - } - } - const int kBufferSize = 4 << 10; - if (initialized_) { - while ([inputStream_ hasBytesAvailable]) { - recvBuffer_.resize(kBufferSize); - uint8_t* bptr = reinterpret_cast(&recvBuffer_[0]); - size_t nbytes = [inputStream_ read:bptr maxLength:kBufferSize]; - recvBuffer_.resize(nbytes); - int flag = 1; - if ([outputStream_ hasSpaceAvailable]) { - flag |= 2; - } - // always try to write - try { - flag = handler_(recvBuffer_, flag); - if (flag == 2) { - [self onShutdownReceived]; - } - } catch (const tvm::Error& e) { - [self close]; - } - } +- (void)viewDidLoad { + // To handle end editing events + self.proxyURL.delegate = self; + self.proxyPort.delegate = self; + self.proxyKey.delegate = self; + + RPCArgs args = get_current_rpc_args(); + self.proxyURL.text = @(args.host_url); + self.proxyPort.text = @(args.host_port).stringValue; + self.proxyKey.text = @(args.key); + + self.ModeSelector.selectedSegmentIndex = args.server_mode; + self->to_connect_ = true; + + // Add border to button + void (^addBorder)(UIButton* btn) = ^(UIButton* btn) { + btn.layer.borderWidth = 2.0f; + btn.layer.borderColor = 
self.ConnectButton.currentTitleColor.CGColor; + btn.layer.cornerRadius = 10; + }; + addBorder(self.ConnectButton); + + // Connect to tracker immediately + if (args.immediate_connect) { + [self disableUIInteraction]; + [self open]; } } -- (void)onShutdownReceived { - [self close]; -} +/*! + * \brief Disable all UI elements + */ +- (void)disableUIInteraction { + void (^disable)(UITextField* field) = ^(UITextField* field) { + field.enabled = NO; + field.backgroundColor = [UIColor lightGrayColor]; + }; -- (void)onWriteAvailable { - if (initSendPtr_ < initBytes_.length()) { - initSendPtr_ += [outputStream_ write:reinterpret_cast(&initBytes_[initSendPtr_]) - maxLength:(initBytes_.length() - initSendPtr_)]; - } - if (initialized_) { - try { - std::string dummy; - int flag = handler_(dummy, 2); - if (flag == 2) { - [self onShutdownReceived]; - } - } catch (const tvm::Error& e) { - [self close]; - } - } + void (^disableButton)(UIButton* btn) = ^(UIButton* btn) { + btn.enabled = NO; + btn.layer.borderColor = btn.currentTitleColor.CGColor; + }; + + disable(self.proxyURL); + disable(self.proxyPort); + disable(self.proxyKey); + disableButton(self.ConnectButton); + self.ModeSelector.enabled = NO; } +/*! + * \brief Start RPC server + */ - (void)open { - constexpr int kRPCMagic = 0xff271; - NSLog(@"Connecting to the proxy server.."); - // Initialize the data states. - key_ = [self.proxyKey.text UTF8String]; - key_ = "server:" + key_; - std::ostringstream os; - int rpc_magic = kRPCMagic; - os.write(reinterpret_cast(&rpc_magic), sizeof(rpc_magic)); - int keylen = static_cast(key_.length()); - os.write(reinterpret_cast(&keylen), sizeof(keylen)); - os.write(key_.c_str(), key_.length()); - initialized_ = false; - initBytes_ = os.str(); - initSendPtr_ = 0; - // Initialize the network. 
- CFReadStreamRef readStream; - CFWriteStreamRef writeStream; - CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.proxyURL.text, - [self.proxyPort.text intValue], &readStream, &writeStream); - inputStream_ = (NSInputStream*)readStream; - outputStream_ = (NSOutputStream*)writeStream; - [inputStream_ setDelegate:self]; - [outputStream_ setDelegate:self]; - [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; - [outputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; - [outputStream_ open]; - [inputStream_ open]; - handler_ = tvm::runtime::CreateServerEventHandler(outputStream_, key_, "%toinit"); - ICHECK(handler_ != nullptr); + RPCArgs args = get_current_rpc_args(); + + RPCServerMode server_mode = static_cast(self.ModeSelector.selectedSegmentIndex); + + server_ = [RPCServer serverWithMode:server_mode]; + server_.host = self.proxyURL.text; + server_.port = self.proxyPort.text.intValue; + server_.key = self.proxyKey.text; + server_.custom_addr = [NSString stringWithUTF8String:args.custom_addr]; + server_.delegate = self; + + [server_ start]; + self.infoText.text = @""; self.statusLabel.text = @"Connecting..."; } +/*! + * \brief Stop RPC server + */ - (void)close { - NSLog(@"Closing the streams."); - [inputStream_ close]; - [outputStream_ close]; - [inputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; - [outputStream_ removeFromRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; - [inputStream_ setDelegate:nil]; - [outputStream_ setDelegate:nil]; - inputStream_ = nil; - outputStream_ = nil; - handler_ = nullptr; - self.statusLabel.text = @"Disconnected"; + [server_ stop]; + self.statusLabel.text = @"Disconnecting..."; } +#pragma mark - Button responders +/*! 
+ * \brief Connect/disconnect button handler + */ - (IBAction)connect:(id)sender { - [self open]; - [[self view] endEditing:YES]; + [[self view] endEditing:YES]; // to hide keyboard + (to_connect_ ^= true) ? [self close] : [self open]; + [self.ConnectButton setTitle:to_connect_ ? @"Connect" : @"Disconenct" + forState:UIControlStateNormal]; } -- (IBAction)disconnect:(id)sender { - [self close]; +#pragma mark - UITextFieldDelegate + +- (BOOL)textFieldShouldReturn:(UITextField*)textField { + [[self view] endEditing:YES]; // to hide keyboard on ret key + return FALSE; +} + +- (void)textFieldDidEndEditing:(UITextField*)textField { + // Update values in app arg cache + RPCArgs args = get_current_rpc_args(); + args.host_url = [self.proxyURL.text UTF8String]; + args.host_port = [self.proxyPort.text intValue]; + args.key = [self.proxyKey.text UTF8String]; + set_current_rpc_args(args); +} + +#pragma mark - RPCServerEvenlListener + +- (void)onError:(NSString*)msg { + dispatch_sync(dispatch_get_main_queue(), ^{ + self.infoText.text = [NSString stringWithFormat:@"Error: %@", msg]; + }); +} + +- (void)onStatusChanged:(RPCServerStatus)status { + dispatch_sync(dispatch_get_main_queue(), ^{ + switch (status) { + case RPCServerStatus_Connected: + if (self.ModeSelector.selectedSegmentIndex == RPCServerMode_Standalone) { + self.infoText.text = [NSString + stringWithFormat:@"IP: %@\nPort: %d", server_.device_addr, server_.actual_port]; + } + self.statusLabel.text = @"Connected"; + break; + case RPCServerStatus_Disconnected: + self.statusLabel.text = @"Disconnected"; + break; + default: + // Nothing + break; + } + }); } @end diff --git a/apps/ios_rpc/tvmrpc/main.m b/apps/ios_rpc/tvmrpc/main.m index d971ce37998bd..4d02f02849d22 100644 --- a/apps/ios_rpc/tvmrpc/main.m +++ b/apps/ios_rpc/tvmrpc/main.m @@ -23,8 +23,10 @@ #import #import "AppDelegate.h" +#import "RPCArgs.h" int main(int argc, char * argv[]) { + update_rpc_args(argc, argv); @autoreleasepool { return UIApplicationMain(argc, 
argv, nil, NSStringFromClass([AppDelegate class])); } diff --git a/apps/ios_rpc/tvmrpcLauncher/Info.plist b/apps/ios_rpc/tvmrpcLauncher/Info.plist deleted file mode 100644 index 45eb19371f178..0000000000000 --- a/apps/ios_rpc/tvmrpcLauncher/Info.plist +++ /dev/null @@ -1,39 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - CFBundleDevelopmentRegion - en - CFBundleExecutable - $(EXECUTABLE_NAME) - CFBundleIdentifier - $(PRODUCT_BUNDLE_IDENTIFIER) - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - $(PRODUCT_NAME) - CFBundlePackageType - BNDL - CFBundleShortVersionString - 1.0 - CFBundleVersion - 1 - - diff --git a/apps/microtvm/zephyr/template_project/boards.json b/apps/microtvm/zephyr/template_project/boards.json new file mode 100644 index 0000000000000..aabed33221503 --- /dev/null +++ b/apps/microtvm/zephyr/template_project/boards.json @@ -0,0 +1,62 @@ +{ + "mimxrt1050_evk": { + "board": "mimxrt1050_evk", + "model": "imxrt10xx", + "is_qemu": false, + "fpu": true + }, + "mps2_an521": { + "board": "mps2_an521", + "model": "mps2_an521", + "is_qemu": true, + "fpu": false + }, + "nrf5340dk_nrf5340_cpuapp": { + "board": "nrf5340dk_nrf5340_cpuapp", + "model": "nrf5340dk", + "is_qemu": false, + "fpu": true + }, + "nucleo_f746zg": { + "board": "nucleo_f746zg", + "model": "stm32f746xx", + "is_qemu": false, + "fpu": true + }, + "nucleo_l4r5zi": { + "board": "nucleo_l4r5zi", + "model": "stm32l4r5zi", + "is_qemu": false, + "fpu": true + }, + "qemu_cortex_r5": { + "board": "qemu_cortex_r5", + "model": "zynq_mp_r5", + "is_qemu": true, + "fpu": true + }, + "qemu_riscv32": { + "board": "qemu_riscv32", + "model": "host", + "is_qemu": true, + "fpu": true + }, + "qemu_riscv64": { + "board": "qemu_riscv64", + "model": "host", + "is_qemu": true, + "fpu": true + }, + "qemu_x86": { + "board": "qemu_x86", + "model": "host", + "is_qemu": true, + "fpu": true + }, + "stm32f746g_disco": { + "board": "stm32f746g_disco", + "model": "stm32f746xx", + "is_qemu": false, + "fpu": true + } +} 
diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index f2e091b2f5b5d..f700b5774c720 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -35,6 +35,7 @@ import tempfile import threading import time +import json import serial import serial.tools.list_ports @@ -57,46 +58,16 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() + +BOARDS = API_SERVER_DIR / "boards.json" + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. -BOARD_PROPERTIES = { - "qemu_x86": { - "board": "qemu_x86", - "model": "host", - }, - "qemu_riscv32": { - "board": "qemu_riscv32", - "model": "host", - }, - "qemu_riscv64": { - "board": "qemu_riscv64", - "model": "host", - }, - "mps2_an521": { - "board": "mps2_an521", - "model": "mps2_an521", - }, - "nrf5340dk_nrf5340_cpuapp": { - "board": "nrf5340dk_nrf5340_cpuapp", - "model": "nrf5340dk", - }, - "stm32f746g_disco": { - "board": "stm32f746g_disco", - "model": "stm32f746xx", - }, - "nucleo_f746zg": { - "board": "nucleo_f746zg", - "model": "stm32f746xx", - }, - "nucleo_l4r5zi": { - "board": "nucleo_l4r5zi", - "model": "stm32l4r5zi", - }, - "qemu_cortex_r5": { - "board": "qemu_cortex_r5", - "model": "zynq_mp_r5", - }, -} +try: + with open(BOARDS) as boards: + BOARD_PROPERTIES = json.load(boards) +except FileNotFoundError: + raise FileNotFoundError(f"Board file {{{BOARDS}}} does not exist.") def check_call(cmd_args, *args, **kwargs): @@ -191,6 +162,7 @@ def _get_device_args(options): "nucleo_l4r5zi": {"idVendor": 0x0483, "idProduct": 0x374B}, "nucleo_f746zg": {"idVendor": 0x0483, "idProduct": 0x374B}, "stm32f746g_disco": {"idVendor": 0x0483, "idProduct": 0x374B}, + "mimxrt1050_evk": {"idVendor": 0x1366, "idProduct": 0x0105}, } @@ -290,9 +262,8 @@ def _get_nrf_device_args(options): help="Name 
of the Zephyr board to build for.", ), server.ProjectOption( - "zephyr_model", - choices=[board["model"] for _, board in BOARD_PROPERTIES.items()], - help="Name of the model for each Zephyr board.", + "config_main_stack_size", + help="Sets CONFIG_MAIN_STACK_SIZE for Zephyr board.", ), ] @@ -351,13 +322,9 @@ def _create_prj_conf(self, project_dir, options): if self._has_fpu(options["zephyr_board"]): f.write("# For models with floating point.\n" "CONFIG_FPU=y\n" "\n") - main_stack_size = None - if self._is_qemu(options) and options["project_type"] == "host_driven": - main_stack_size = 1536 - # Set main stack size, if needed. - if main_stack_size is not None: - f.write(f"CONFIG_MAIN_STACK_SIZE={main_stack_size}\n") + if options.get("config_main_stack_size") is not None: + f.write(f"CONFIG_MAIN_STACK_SIZE={options['config_main_stack_size']}\n") f.write("# For random number generation.\n" "CONFIG_TEST_RANDOM_GENERATOR=y\n") @@ -384,6 +351,9 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # by launching the copy. shutil.copy2(__file__, project_dir / os.path.basename(__file__)) + # Copy boards.json file to generated project. + shutil.copy2(BOARDS, project_dir / BOARDS.name) + # Place Model Library Format tarball in the special location, which this script uses to decide # whether it's being invoked in a template or generated project. 
project_model_library_format_tar_path = project_dir / MODEL_LIBRARY_FORMAT_RELPATH @@ -471,20 +441,10 @@ def _is_qemu(cls, options): or options["zephyr_board"] in cls._KNOWN_QEMU_ZEPHYR_BOARDS ) - _KNOWN_FPU_ZEPHYR_BOARDS = ( - "nucleo_f746zg", - "nucleo_l4r5zi", - "nrf5340dk_nrf5340_cpuapp", - "qemu_cortex_r5", - "qemu_riscv32", - "qemu_riscv64", - "qemu_x86", - "stm32f746g_disco", - ) - @classmethod def _has_fpu(cls, zephyr_board): - return zephyr_board in cls._KNOWN_FPU_ZEPHYR_BOARDS + fpu_boards = [name for name, board in BOARD_PROPERTIES.items() if board["fpu"]] + return zephyr_board in fpu_boards def flash(self, options): if self._is_qemu(options): @@ -586,6 +546,10 @@ def _find_openocd_serial_port(cls, options): return ports[0].device + @classmethod + def _find_jlink_serial_port(cls, options): + return cls._find_openocd_serial_port(options) + @classmethod def _find_serial_port(cls, options): flash_runner = _get_flash_runner() @@ -596,9 +560,10 @@ def _find_serial_port(cls, options): if flash_runner == "openocd": return cls._find_openocd_serial_port(options) - raise FlashRunnerNotSupported( - f"Don't know how to deduce serial port for flash runner {flash_runner}" - ) + if flash_runner == "jlink": + return cls._find_jlink_serial_port(options) + + raise RuntimeError(f"Don't know how to deduce serial port for flash runner {flash_runner}") def __init__(self, options): self._options = options diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake index aaddfb054366e..8f3f638309cd6 100644 --- a/cmake/utils/FindCUDA.cmake +++ b/cmake/utils/FindCUDA.cmake @@ -89,12 +89,16 @@ macro(find_cuda use_cuda use_cudnn) ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib NO_DEFAULT_PATH) + # search default path if cannot find cublas in non-default + find_library(CUDA_CUBLAS_LIBRARY cublas) find_library(CUDA_CUBLASLT_LIBRARY NAMES cublaslt cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib NO_DEFAULT_PATH) + # search default path if 
cannot find cublaslt in non-default + find_library(CUDA_CUBLASLT_LIBRARY NAMES cublaslt cublasLt) endif(MSVC) # find cuDNN diff --git a/docker/bash.sh b/docker/bash.sh index 372cfded8f89e..cbd71870747c1 100755 --- a/docker/bash.sh +++ b/docker/bash.sh @@ -22,7 +22,7 @@ # # Usage: docker/bash.sh [-i|--interactive] [--net=host] [-t|--tty] # [--mount MOUNT_DIR] [--repo-mount-point REPO_MOUNT_POINT] -# [--dry-run] +# [--dry-run] [--name NAME] # [--] [COMMAND] # # Usage: docker/bash.sh @@ -40,7 +40,7 @@ function show_usage() { cat < [--] [COMMAND] -h, --help @@ -85,6 +85,11 @@ Usage: docker/bash.sh [-i|--interactive] [--net=host] [-t|--tty] Print the docker command to be run, but do not execute it. +--name + + Set the name of the docker container, and the hostname that will + appear inside the container. + DOCKER_IMAGE_NAME The name of the docker container to be run. This can be an @@ -118,6 +123,7 @@ USE_NET_HOST=false DOCKER_IMAGE_NAME= COMMAND=bash MOUNT_DIRS=( ) +CONTAINER_NAME= # TODO(Lunderberg): Remove this if statement and always set to # "${REPO_DIR}". The consistent directory for Jenkins is currently @@ -180,6 +186,15 @@ while (( $# )); do shift ;; + --name) + if [[ -n "$2" ]]; then + CONTAINER_NAME="$2" + shift 2 + else + parse_error 'ERROR: --name requires a non empty argument' + fi + ;; + --dry-run) DRY_RUN=true shift @@ -312,6 +327,11 @@ if ${TTY}; then DOCKER_FLAGS+=( --tty ) fi +# Setup the docker name and the hostname inside the container +if [[ ! 
-z "${CONTAINER_NAME}" ]]; then + DOCKER_FLAGS+=( --name ${CONTAINER_NAME} --hostname ${CONTAINER_NAME}) +fi + # Expose external directories to the docker container for MOUNT_DIR in ${MOUNT_DIRS[@]+"${MOUNT_DIRS[@]}"}; do DOCKER_MOUNT+=( --volume "${MOUNT_DIR}:${MOUNT_DIR}" ) diff --git a/docker/build.sh b/docker/build.sh index 3b58bcc52a755..4e1a9b346895b 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -23,7 +23,8 @@ # Usage: build.sh [--tag ] # [--dockerfile ] [-it] # [--net=host] [--cache-from ] -# [--context-path ] [] +# [--name CONTAINER_NAME] [--context-path ] +# [] # # CONTAINER_TYPE: Type of the docker container used the run the build, # e.g. "ci_cpu", "ci_gpu" @@ -38,6 +39,9 @@ # IMAGE_NAME: An image to be as a source for cached layers when building the # Docker image requested. # +# CONTAINER_NAME: The name of the docker container, and the hostname that will +# appear inside the container. +# # CONTEXT_PATH: Path to be used for relative path resolution when building # the Docker images. # @@ -95,6 +99,12 @@ else echo "Using default context path: ${DOCKER_CONTEXT_PATH}" fi +if [[ "$1" == "--name" ]]; then + CI_DOCKER_EXTRA_PARAMS+=("--name ${2} --hostname ${2}") + echo "Using container name ${2}" + shift 2 +fi + if [[ ! 
-f "${DOCKERFILE_PATH}" ]]; then echo "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" exit 1 diff --git a/docker/install/ubuntu_install_core.sh b/docker/install/ubuntu_install_core.sh index 2a50afcf59850..f3e97cbf28b06 100755 --- a/docker/install/ubuntu_install_core.sh +++ b/docker/install/ubuntu_install_core.sh @@ -22,9 +22,17 @@ set -o pipefail # install libraries for building c++ core on ubuntu apt-get update && apt-get install -y --no-install-recommends \ - git make libgtest-dev cmake wget unzip libtinfo-dev libz-dev\ + git make google-mock libgtest-dev cmake wget unzip libtinfo-dev libz-dev \ libcurl4-openssl-dev libssl-dev libopenblas-dev g++ sudo \ apt-transport-https graphviz pkg-config curl - -cd /usr/src/gtest && cmake CMakeLists.txt && make && cp *.a /usr/lib +if [[ -d /usr/src/googletest ]]; then + # Single package source (Ubuntu 18.04) + # googletest is installed via libgtest-dev + cd /usr/src/googletest && cmake CMakeLists.txt && make && cp -v {googlemock,googlemock/gtest}/*.a /usr/lib +else + # Split source package (Ubuntu 16.04) + # libgtest-dev and google-mock + cd /usr/src/gtest && cmake CMakeLists.txt && make && cp -v *.a /usr/lib + cd /usr/src/gmock && cmake CMakeLists.txt && make && cp -v *.a /usr/lib +fi diff --git a/docker/install/ubuntu_install_paddle.sh b/docker/install/ubuntu_install_paddle.sh index 267d59105c063..c7f9d43a3bd40 100644 --- a/docker/install/ubuntu_install_paddle.sh +++ b/docker/install/ubuntu_install_paddle.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip install paddlepaddle==2.1.2 +pip install paddlepaddle==2.1.3 diff --git a/docker/lint.sh b/docker/lint.sh index e709bfb08445a..4bad5ea3b923c 100755 --- a/docker/lint.sh +++ b/docker/lint.sh @@ -51,6 +51,9 @@ function run_lint_step() { cpplint) cmd=( tests/lint/cpplint.sh ) ;; + flake8) + cmd=( tests/lint/flake8.sh ) + ;; pylint) cmd=( tests/lint/pylint.sh ) ;; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index fa1861051e2f9..715c96eb6ea52 100644 --- 
a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v return input; } +/*! + * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attrs Key/values attributes to add to \p input. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + */ +template +inline TFunc WithAttrs(TFunc input, Map attrs) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + for (const auto& pair : attrs) { + node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); + } + } else { + node->attrs = DictAttrs(std::move(attrs)); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 13b984d9cb355..5ee719f9964f8 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target"; * Type: String */ constexpr const char* kGlobalSymbol = "global_symbol"; + +/*! + * \brief The device type which will hold each of the functions parameters. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. + * + * Type: Array (but interpreted as Array) + */ +constexpr const char* kParamDeviceTypes = "param_device_types"; + +/*! + * \brief The device type which will hold the function result. + * + * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but + * may be included as an annotation on user programs. 
+ * + * Type: Integer (but interpreted as DLDeviceType) + */ +constexpr const char* kResultDeviceType = "result_device_type"; + } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h new file mode 100644 index 0000000000000..08553a001374e --- /dev/null +++ b/include/tvm/meta_schedule/arg_info.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_ARG_INFO_H_ +#define TVM_META_SCHEDULE_ARG_INFO_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The argument information. */ +class ArgInfoNode : public runtime::Object { + public: + static constexpr const char* _type_key = "meta_schedule.ArgInfo"; + TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object); + + public: + /*! \brief Default destructor. */ + virtual ~ArgInfoNode() = default; + /*! \brief Converts the ArgInfo to its corresponding JSON representation. */ + virtual ObjectRef AsJSON() const = 0; +}; + +/*! + * \brief Managed reference to ArgInfoNode + * \sa ArgInfoNode + */ +class ArgInfo : public runtime::ObjectRef { + public: + /*! 
+ * \brief Parse the argument information from a JSON object. + * \param json_obj The json object to parse. + * \return The argument information parsed. + */ + TVM_DLL static ArgInfo FromJSON(const ObjectRef& json_obj); + /*! + * \brief Extract a list of the argument information from PrimFunc. + * \param func The PrimFunc to get argument information from. + * \return An array of the argument information derived. + */ + TVM_DLL static Array FromPrimFunc(const tir::PrimFunc& func); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode); + + protected: + ArgInfo() = default; +}; + +/*! \brief The tensor argument information. */ +class TensorInfoNode : public ArgInfoNode { + public: + /*! \brief The data type of the tensor. */ + runtime::DataType dtype; + /*! \brief The shape of the tensor. */ + runtime::ShapeTuple shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("dtype", &dtype); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "meta_schedule.TensorInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode); + + public: + ObjectRef AsJSON() const; +}; + +/*! + * \brief Managed reference to TensorInfoNode + * \sa TensorInfoNode + */ +class TensorInfo : public ArgInfo { + public: + /*! + * \brief Constructor of TensorInfo. + * \param dtype The data type of the tensor argument. + * \param shape The shape tuple of the tensor argument. + */ + TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape); + /*! + * \brief Parse the argument information from a JSON object. + * \param json_obj The json object to parse. + * \return The argument information parsed. 
+ */ + TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_ARG_INFO_H_ diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index d0985071b773e..19358552df10a 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -26,7 +25,7 @@ namespace tvm { namespace meta_schedule { -/*! \brief The builder's input. */ +/*! \brief The builder's input, containing an IRModule and the target. */ class BuilderInputNode : public runtime::Object { public: /*! \brief The IRModule to be built. */ @@ -58,7 +57,7 @@ class BuilderInput : public runtime::ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BuilderInput, runtime::ObjectRef, BuilderInputNode); }; -/*! \brief The builder's output. */ +/*! \brief The builder's output, containing the artifact path or error message if any. */ class BuilderResultNode : public runtime::Object { public: /*! \brief The path to the built artifact. */ diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h new file mode 100644 index 0000000000000..7ba3c207e349f --- /dev/null +++ b/include/tvm/meta_schedule/database.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_DATABASE_H_ +#define TVM_META_SCHEDULE_DATABASE_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief A workload, i.e. an IRModule and its structural hash. */ +class WorkloadNode : public runtime::Object { + public: + /*! \brief The type of structural hash */ + using THashCode = size_t; + /*! \brief The workload's IRModule. */ + IRModule mod; + /*! \brief The workload's structural hash. */ + THashCode shash; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("mod", &mod); + // `shash` is not visited because TVM FFI doesn't support uint64_t + } + + static constexpr const char* _type_key = "meta_schedule.Workload"; + TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object); + + /*! + * \brief Export the workload to a JSON string. + * \return An array containing the structural hash and the base64 json string. + */ + ObjectRef AsJSON() const; +}; + +/*! + * \brief Managed reference to WorkloadNode. + * \sa WorkloadNode + */ +class Workload : public runtime::ObjectRef { + public: + using THashCode = WorkloadNode::THashCode; + /*! + * \brief Constructor of Workload. + * \param mod The workload's IRModule. + */ + TVM_DLL explicit Workload(IRModule mod); + /*! + * \brief Constructor of Workload. + * \param mod The workload's IRModule. + * \param shash The workload's structural hash. + */ + TVM_DLL explicit Workload(IRModule mod, THashCode shash); + /*! + * \brief Create a workload from a json object. + * \param json_obj The json object. + * \return The created workload. 
+ */ + TVM_DLL static Workload FromJSON(const ObjectRef& json_obj); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Workload, runtime::ObjectRef, WorkloadNode); +}; + +/*! \brief The hash method for Workload */ +struct WorkloadHash { + size_t operator()(const Workload& a) const { return a->shash; } +}; + +/*! \brief The equality check for Workload */ +struct WorkloadEqual { + bool operator()(const Workload& a, const Workload& b) const { + return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + } +}; + +/*! \brief The class of tuning records. */ +class TuningRecordNode : public runtime::Object { + public: + /*! \brief The trace tuned. */ + tir::Trace trace; + /*! \brief The profiling result in seconds. */ + Array run_secs; + /*! \brief The workload. */ + Workload workload{nullptr}; + /*! \brief The target for tuning. */ + Target target; + /*! \brief The argument information. */ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("trace", &trace); + v->Visit("run_secs", &run_secs); + v->Visit("workload", &workload); + v->Visit("target", &target); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.TuningRecord"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); + + /*! + * \brief Export the tuning record to a JSON string. + * \return An array containing the trace, running secs, serialized target, and + * argument information. + */ + ObjectRef AsJSON() const; +}; + +/*! + * \brief The managed reference of TuningRecordNode. + * \sa TuningRecordNode + */ +class TuningRecord : public runtime::ObjectRef { + public: + /*! + \brief Constructor of a tuning record. + \param trace The trace of the tuning record. + \param run_secs The running time of the tuning record. + \param workload The workload of the tuning record. + \param target The target of the tuning record. + \param args_info The argument information of the tuning record. 
+ */ + TVM_DLL explicit TuningRecord(tir::Trace trace, Array run_secs, Workload workload, + Target target, Array args_info); + /*! + * \brief Create a tuning record from a json object. + * \param json_obj The json object. + * \param workload The workload. + * \return The tuning record created. + */ + TVM_DLL static TuningRecord FromJSON(const ObjectRef& json_obj, const Workload& workload); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TuningRecord, runtime::ObjectRef, TuningRecordNode); +}; + +/* \brief The abstract interface of database. */ +class DatabaseNode : public runtime::Object { + public: + /*! \brief Default destructor */ + virtual ~DatabaseNode() = default; + /*! + * \brief Look up or add workload to the database if missing. + * \param mod The IRModule to be searched for or added. + * \return The workload corresponding to the given IRModule. + */ + virtual Workload CommitWorkload(const IRModule& mod) = 0; + /*! + * \brief Add a tuning record to the database. + * \param record The tuning record to be added. + */ + virtual void CommitTuningRecord(const TuningRecord& record) = 0; + /*! + * \brief Get the top K tuning records of given workload from the database. + * \param workload The workload to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. + */ + virtual Array GetTopK(const Workload& workload, int top_k) = 0; + /*! + * \brief Get the size of the database. + * \return The size of the database. + */ + virtual int64_t Size() = 0; + + static constexpr const char* _type_key = "meta_schedule.Database"; + TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); +}; + +/*! \brief The database with customized methods on the python-side. */ +class PyDatabaseNode : public DatabaseNode { + public: + /*! + * \brief The function type of `CommitWorkload` method. + * \param mod The IRModule to be searched for or added. 
+ * \return The workload corresponding to the given IRModule. + */ + using FCommitWorkload = runtime::TypedPackedFunc; + /*! + * \brief The function type of `CommitTuningRecord` method. + * \param record The tuning record to be added. + */ + using FCommitTuningRecord = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GetTopK` method. + * \param workload The workload to be searched for. + * \param top_k The number of top records to be returned. + * \return An array of top K tuning records for the given workload. + */ + using FGetTopK = runtime::TypedPackedFunc(const Workload&, int)>; + /*! + * \brief The function type of `Size` method. + * \return The size of the database. + */ + using FSize = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `CommitWorkload` function. */ + FCommitWorkload f_commit_workload; + /*! \brief The packed function to the `CommitTuningRecord` function. */ + FCommitTuningRecord f_commit_tuning_record; + /*! \brief The packed function to the `GetTopK` function. */ + FGetTopK f_get_top_k; + /*! \brief The packed function to the `Size` function. */ + FSize f_size; + + void VisitAttrs(tvm::AttrVisitor* v) { + // PackedFuncs are all not visited, because the reflection system doesn't take care of them, + // so it cannot be accessible on the python side. If there is such need from the future, + // we can then add corresponding accessor methods to help access on python. 
+ // + // `f_commit_workload` is not visited + // `f_commit_tuning_record` is not visited + // `f_get_top_k` is not visited + // `f_size` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.PyDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode); + + Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); } + + void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); } + + Array GetTopK(const Workload& workload, int top_k) final { + return f_get_top_k(workload, top_k); + } + + int64_t Size() final { return f_size(); } +}; + +/*! + * \brief Managed reference to DatabaseNode. + * \sa DatabaseNode + */ +class Database : public runtime::ObjectRef { + public: + /*! + * \brief Create a default database that uses JSON file for tuning records. + * \param path_workload The path to the workload table. + * \param path_tuning_record The path to the database table. + * \param allow_missing Whether to create new file when the given path is not found. + */ + TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, + bool allow_missing); + /*! + * \brief Create a database with customized methods on the python-side. + * \param f_commit_workload The packed function of `CommitWorkload`. + * \param f_commit_tuning_record The packed function of `CommitTuningRecord`. + * \param f_get_top_k The packed function of `GetTopK`. + * \param f_size The packed function of `Size`. + * \return The created database. 
+ */ + TVM_DLL static Database PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, + PyDatabaseNode::FGetTopK f_get_top_k, + PyDatabaseNode::FSize f_size); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_DATABASE_H_ diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h new file mode 100644 index 0000000000000..a45a4898d64ae --- /dev/null +++ b/include/tvm/meta_schedule/runner.h @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_RUNNER_H_ +#define TVM_META_SCHEDULE_RUNNER_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The runner's input. */ +class RunnerInputNode : public runtime::Object { + public: + /*! \brief The path to the built artifact. */ + String artifact_path; + /*! \brief The type of device. */ + String device_type; + /*! \brief The argument information. 
*/ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("artifact_path", &artifact_path); + v->Visit("device_type", &device_type); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerInput"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerInputNode + * \sa RunnerInputNode + */ +class RunnerInput : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of RunnerInput + * \param artifact_path The path to the built artifact. + * \param device_type The type of device. + * \param args_info The argument information. + */ + TVM_DLL explicit RunnerInput(String artifact_path, String device_type, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerInput, runtime::ObjectRef, RunnerInputNode); +}; + +/*! \brief The runner's output. */ +class RunnerResultNode : public runtime::Object { + public: + /*! \brief The run time in seconds.*/ + Optional> run_secs; + /*! \brief The error message, if any. */ + Optional error_msg; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("run_secs", &run_secs); + v->Visit("error_msg", &error_msg); + } + + static constexpr const char* _type_key = "meta_schedule.RunnerResult"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerResultNode + * \sa RunnerResultNode + */ +class RunnerResult : public runtime::ObjectRef { + public: + /*! + * \brief Constructor + * \brief The run time in seconds. + * \brief The error message, if any. + */ + TVM_DLL explicit RunnerResult(Optional> run_secs, Optional error_msg); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerResult, runtime::ObjectRef, RunnerResultNode); +}; + +/*! + * \brief A class to asynchronously fetch runner's output. 
+ * \note The API design is consistent with python's concurrent.futures.Future: + * https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future + */ +class RunnerFutureNode : public runtime::Object { + public: + /*! + * \brief The function type to check whether the runner has finished. + * \return Whether the runner's output is ready. + */ + using FDone = runtime::TypedPackedFunc; + /*! + * \brief The function type to fetch runner output if it is ready. + * \return The runner's output. + */ + using FResult = runtime::TypedPackedFunc; + + /*! \brief The packed function to check whether the runner has finished. */ + FDone f_done; + /*! \brief The packed function to fetch runner output if it is ready. */ + FResult f_result; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_done` is not visited + // `f_result` is not visited + } + + /*! + * \brief Check whether the runner has finished. + * \return A boolean indicating whether the runner has finished. + */ + bool Done() const { return f_done(); } + /*! + * \brief Fetch the runner's output if it is ready. + * \return The runner's output. + */ + RunnerResult Result() const { return f_result(); } + + static constexpr const char* _type_key = "meta_schedule.RunnerFuture"; + TVM_DECLARE_FINAL_OBJECT_INFO(RunnerFutureNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerFutureNode + * \sa RunnerFutureNode + */ +class RunnerFuture : public runtime::ObjectRef { + public: + using FDone = RunnerFutureNode::FDone; + using FResult = RunnerFutureNode::FResult; + + /*! + * \brief Constructor of RunnerFuture + * \param f_done The packed function to check whether the runner has finished. + * \param f_result The packed function to fetch runner output if it is ready. + */ + TVM_DLL explicit RunnerFuture(FDone f_done, FResult f_result); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RunnerFuture, runtime::ObjectRef, + RunnerFutureNode); +}; + +/*! \brief The abstract runner interface. 
*/ +class RunnerNode : public runtime::Object { + public: + /*! + * \brief The function type to run the built artifacts and get runner futures. + * \param input The runner's inputs. + * \return The runner futures. + * \sa RunnerFuture + */ + using FRun = runtime::TypedPackedFunc(Array)>; + + /*! \brief Default destructor */ + virtual ~RunnerNode() = default; + + /*! + * \brief Run the built artifact and get runner futures. + * \param runner_inputs The runner's inputs. + * \return The runner futures. + */ + virtual Array Run(Array runner_inputs) = 0; + + static constexpr const char* _type_key = "meta_schedule.Runner"; + TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, runtime::Object); +}; + +/*! + * \brief Managed reference to RunnerNode + * \sa RunnerNode + */ +class Runner : public runtime::ObjectRef { + public: + using FRun = RunnerNode::FRun; + + /*! + * \brief Create a runner with customized build method on the python-side. + * \param f_run The packed function to run the built artifacts and get runner futures. + * \return The runner created. + */ + TVM_DLL static Runner PyRunner(FRun f_run); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Runner, runtime::ObjectRef, RunnerNode); +}; + +/*! \brief An abstract runner with customized build method on the python-side. */ +class PyRunnerNode : public RunnerNode { + public: + /*! \brief The packed function to run the built artifacts and get runner futures. 
*/ + FRun f_run; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_run` is not visited + } + + Array Run(Array runner_inputs) final { return f_run(runner_inputs); } + + static constexpr const char* _type_key = "meta_schedule.PyRunner"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_RUNNER_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h new file mode 100644 index 0000000000000..941dae4336e1b --- /dev/null +++ b/include/tvm/meta_schedule/search_strategy.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ +#define TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ + +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +/*! \brief The schedule (with input shapes) to be measured. */ +class MeasureCandidateNode : public runtime::Object { + public: + /*! \brief The schedule for measurement. */ + tir::Schedule sch; + /*! \brief The argument information, e.g., (shape, dtype) for tensors. 
*/ + Array args_info; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("sch", &sch); + v->Visit("args_info", &args_info); + } + + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); +}; + +/*! + * \brief Managed reference to MeasureCandidateNode. + * \sa MeasureCandidateNode + */ +class MeasureCandidate : public runtime::ObjectRef { + public: + /*! + * \brief Constructor of MeasureCandidate. + * \param sch The schedule for measurement. + * \param args_info The argument information, e.g., (shape, dtype) for tensors. + */ + TVM_DLL MeasureCandidate(tir::Schedule sch, Array args_info); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MeasureCandidate, ObjectRef, MeasureCandidateNode); +}; + +/*! + * \brief The search strategy for measure candidates generation. + * \note The relationship between SearchStrategy and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── 
Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ +class SearchStrategyNode : public runtime::Object { + public: + /*! \brief Virtual destructor */ + virtual ~SearchStrategyNode() = default; + + /*! + * \brief Initialize the search strategy with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! + * \brief Pre-tuning for the search strategy. + * \param design_spaces The design spaces for pre-tuning. + * \note Pre-tuning is supposed to be called before the tuning process and after the + * initialization. Because the search strategy is stateful, we can always call pretuning + * and reset the search strategy. + */ + virtual void PreTuning(const Array& design_spaces) = 0; + + /*! + * \brief Post-tuning for the search strategy. + * \note Post-tuning is supposed to be called after the tuning process and before we reset the + * search strategy with another pre-tuning. Post-tuning can be empty. + */ + virtual void PostTuning() = 0; + + /*! + * \brief Generate measure candidates from design spaces for measurement. + * \return The measure candidates generated, nullptr if finished. + */ + virtual Optional> GenerateMeasureCandidates() = 0; + + /*! + * \brief Update the search strategy with measurement results. + * \param results The measurement results from the runner. + */ + virtual void NotifyRunnerResults(const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); +}; + +/*! \brief The python side customizable class for measure candidate generation */ +class PySearchStrategyNode : public SearchStrategyNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. 
+ * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `PreTuning` method. + * \param design_spaces The design spaces for pre-tuning. + */ + using FPreTuning = runtime::TypedPackedFunc&)>; + /*! \brief The function type of `PostTuning` method. */ + using FPostTuning = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateMeasureCandidates` method. + * \return The measure candidates generated, nullptr if finished. + */ + using FGenerateMeasureCandidates = runtime::TypedPackedFunc>()>; + /*! + * \brief The function type of `NotifyRunnerResults` method. + * \param results The measurement results from the runner. + */ + using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` method. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `PreTuning` method. */ + FPreTuning f_pre_tuning; + /*! \brief The packed function to the `PostTuning` method. */ + FPostTuning f_post_tuning; + /*! \brief The packed function to the `GenerateMeasureCandidates` method. */ + FGenerateMeasureCandidates f_generate_measure_candidates; + /*! \brief The packed function to the `NotifyRunnerResults` method. 
*/ + FNotifyRunnerResults f_notify_runner_results; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_pre_tuning` is not visited + // `f_post_tuning` is not visited + // `f_generate_measure_candidates` is not visited + // `f_notify_runner_results` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + this->f_initialize_with_tune_context(context); + } + + void PreTuning(const Array& design_spaces) final { + this->f_pre_tuning(design_spaces); + } + + void PostTuning() final { this->f_post_tuning(); } + + Optional> GenerateMeasureCandidates() final { + return this->f_generate_measure_candidates(); + } + + void NotifyRunnerResults(const Array& results) final { + this->f_notify_runner_results(results); + } + + static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySearchStrategyNode, SearchStrategyNode); +}; + +/*! + * \brief Managed reference to SearchStrategyNode. + * \sa SearchStrategyNode + */ +class SearchStrategy : public runtime::ObjectRef { + public: + /*! + * \brief Create a search strategy with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_pre_tuning The packed function of `PreTuning`. + * \param f_post_tuning The packed function of `PostTuning`. + * \param f_generate_measure_candidates The packed function of `GenerateMeasureCandidates`. + * \param f_notify_runner_results The packed function of `NotifyRunnerResults`. + * \return The search strategy created. 
+ */ + TVM_DLL static SearchStrategy PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results); + + /*! + * \brief Constructor of replay trace search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for trace replaying. + */ + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SEARCH_STRATEGY_H_ diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h new file mode 100644 index 0000000000000..3dc181e05d8a7 --- /dev/null +++ b/include/tvm/meta_schedule/space_generator.h @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#ifndef TVM_META_SCHEDULE_SPACE_GENERATOR_H_ +#define TVM_META_SCHEDULE_SPACE_GENERATOR_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +/*! + * \brief The abstract class for design space generation. + * \note The relationship between SpaceGenerator and other classes are as follows: + ┌──────────────────────────────────────────────────────────────┐ + ┌──┴───────────────────────────────────────────────────────────┐ │ +┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ +│ ┌─────────────────────┐ │ │ │ +│ │ │ Generate │ │ │ +│ │ Space Generator ├──────────────┐ │ │ │ +│ │ │ │ │ │ │ +│ └─────────────────────┘ ▼ │ │ │ +│ Design Space │ │ │ +│ ┌─────────────────────┐ │ │ │ │ +│ Generate │ │ Pretuning │ │ │ │ +│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ +│ │ │ │ │ ├──┘ +│ │ └─────────────────────┘ ├──┘ +└────┼─────────────────────────────────────────────────────────┘ + │ + │ +┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ +│ │ ┌───────────┐ │ +│ │ Send to │ │ Send to │ +│ ▼ ┌─────────────►│ Builder ├──────────┐ │ +│ Measure Candidate │ Builder │ │ Runner │ │ +│ │ │ └───────────┘ │ │ +│ │ ┌────────────┴────────┐ │ │ +│ │ │ │ ┌───────────┐ │ │ +│ └────►│ Task Scheduler │ │ │ │ │ +│ │ │ │ Runner │◄─────────┘ │ +│ └─────────────────────┘ │ │ │ +│ ▲ └─────┬─────┘ │ +│ │ │ │ +│ └─── Runner Future ◄────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +*/ +class SpaceGeneratorNode : public Object { + public: + /*! \brief Default destructor */ + virtual ~SpaceGeneratorNode() = default; + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! 
+ * \brief Generate design spaces given a module. + * \param mod The module used for design space generation. + * \return The generated design spaces, i.e., schedules. + */ + virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + + static constexpr const char* _type_key = "meta_schedule.SpaceGenerator"; + TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object); +}; + +/*! \brief The design space generator with customized methods on the python-side. */ +class PySpaceGeneratorNode : public SpaceGeneratorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateDesignSpace` method. + * \param mod The module used for design space generation. + * \return The generated design spaces, i.e., schedules. + */ + using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `GenerateDesignSpace` function. */ + FGenerateDesignSpace f_generate_design_space; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_generate_design_space` is not visited + } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + f_initialize_with_tune_context(tune_context); + } + + Array GenerateDesignSpace(const IRModule& mod) final { + return f_generate_design_space(mod); + } + + static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); +}; + +/*! + * \brief Managed reference to SpaceGeneratorNode. + * \sa SpaceGeneratorNode + */ +class SpaceGenerator : public ObjectRef { + protected: + SpaceGenerator() = default; + + public: + /*! 
+ * \brief Create a design space generator with customized methods on the python-side. + * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. + * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PySpaceGenerator( + PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func, + PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func); + + /*! + * \brief Create a design space generator that is union of multiple design space generators. + * \param space_generators An array of design space generators to be unioned. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SPACE_GENERATOR_H_ diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h new file mode 100644 index 0000000000000..87a3a491c8f3d --- /dev/null +++ b/include/tvm/meta_schedule/tune_context.h @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_TUNE_CONTEXT_H_ +#define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +/*! \brief The auto tuning context. */ +class TuneContextNode : public runtime::Object { + public: + /*! \brief The workload to be tuned. */ + Optional mod; + /*! \brief The target to be tuned for. */ + Optional target; + /*! \brief The design space generator. */ + Optional space_generator; + /*! \brief The name of the tuning task. */ + Optional task_name; + /*! \brief The random state. */ + support::LinearCongruentialEngine::TRandState rand_state; + /*! \brief The number of threads to be used. */ + int num_threads; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("mod", &mod); + v->Visit("target", &target); + v->Visit("space_generator", &space_generator); + v->Visit("task_name", &task_name); + v->Visit("rand_state", &rand_state); + v->Visit("num_threads", &num_threads); + } + + static constexpr const char* _type_key = "meta_schedule.TuneContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); +}; + +/*! + * \brief Managed reference to TuneContextNode. + * \sa TuneContextNode + */ +class TuneContext : public runtime::ObjectRef { + public: + /*! + * \brief Constructor. + * \param mod The workload to be tuned. + * \param target The target to be tuned for. + * \param space_generator The design space generator. + * \param task_name The name of the tuning task. + * \param rand_state The random state. + * \param num_threads The number of threads to be used. 
+ */ + TVM_DLL explicit TuneContext(Optional mod, // + Optional target, // + Optional space_generator, // + Optional task_name, // + support::LinearCongruentialEngine::TRandState rand_state, // + int num_threads); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TuneContext, ObjectRef, TuneContextNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_TUNE_CONTEXT_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 18e8db0ace228..4b9ae5dcc830c 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -229,7 +229,7 @@ class ReflectionVTable::Registry { }; #define TVM_REFLECTION_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflectiion + static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection /*! * \brief Directly register reflection VTable. diff --git a/include/tvm/parser/parser.h b/include/tvm/parser/parser.h index 7673eec2a337f..8c27220509057 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/parser/parser.h @@ -23,6 +23,7 @@ * \file parser.h * \brief A parser for TVM IR. */ +#include #include #include @@ -32,8 +33,11 @@ namespace tvm { namespace parser { -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module = Optional()); +using MetaTable = Map>; + +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module = Optional(), + const MetaTable& init_meta_table = MetaTable()); } // namespace parser } // namespace tvm diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index 8379e6471561d..85ac3f36ff607 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -32,15 +32,64 @@ namespace tvm { namespace relay { /*! - * \brief Options for the device annotation operators. + * \brief Attributes for the "on_device" special operator. 
+ * + * The Relay call (aka 'annotation'): + * \code + * on_device(sub_expr, device_type=2) + * \endcode + * constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2 + * (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be + * executed and stored on a different device. If so the compiler will automatically insert a + * "device_copy" call to mediate the transition between devices. + * + * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: + * \code + * multiply(on_device(add(%x, %y), device_type=2), %z) + * \endcode + * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. + * The compiler will rewrite this to: + * \code + * multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * + * The Relay call + * \code + * on_device(sub_expr, device_type=2, is_fixed=True) + * \endcode + * is similar to the above, however the annotation itself must appear in an expression on the + * same device. The compiler will check the devices are consistent, and will not insert any + * "device_copy" call. This form of annotation shouldn't be necessary in user programs. However + * it is needed by the \p PlanDevices pass to fully specify the results of device planning so that + * the pass is idempotent. + * + * E.g.: The following program is equivalent to the above: + * \code + * let %a = on_device(add(%x, %y), device_type=2, is_fixed=True) + * multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z) + * \endcode + * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored + * on the GPU. */ struct OnDeviceAttrs : public tvm::AttrsNode { - int device_type; + // TODO(mbs): Replace device types with TargetDevice. + /*! \brief Device type on which argument expression should be evaluated. */ + int device_type = kInvalidDeviceType; + /*! 
+ * \brief If true, the result device must also be \p device_type and device planning should + * not insert any "device_copy" calls to respect this annotation. + * + * This is used by the device planning pass itself when annotating the planned program. + */ + bool is_fixed = false; TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe("The virutal device/context type that an expression is annotated with.") + .describe("The type of the virtual device which should hold the expression result.") .set_default(0); + TVM_ATTR_FIELD(is_fixed) + .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + .set_default(false); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 7da92b3ff7639..f7b0a04f45fa8 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -35,6 +35,7 @@ namespace relay { * \brief Options for the device copy operators. */ struct DeviceCopyAttrs : public tvm::AttrsNode { + // TODO(mbs): Should be TargetDevice. int dst_dev_type; int src_dev_type; diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index e3f9bad17ef5b..0e04b0936f249 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -60,6 +60,18 @@ struct ExpandDimsAttrs : public tvm::AttrsNode { } }; // struct ExpandDimsAttrs +/*! \brief Attributes used in dynamic expand_dims operators */ +struct DynExpandDimsAttrs : public tvm::AttrsNode { + int num_newaxis; + + TVM_DECLARE_ATTRS(DynExpandDimsAttrs, "relay.attrs.DynExpandDimsAttrs") { + TVM_ATTR_FIELD(num_newaxis) + .describe("Number of axes to be inserted. Should be >= 0.") + .set_lower_bound(0) + .set_default(1); + } +}; // struct ExpandDimsAttrs + /*! 
\brief Attributes used in concatenate operators */ struct ConcatenateAttrs : public tvm::AttrsNode { int axis; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 688ad8254fa85..f96faffb24f4f 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -37,6 +37,7 @@ #include #include #include + namespace tvm { namespace relay { @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor { * * MixedModeVisitor provides the same recursive API as ExprVisitor, and uses * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions - * of the graph and processes them iteratatively to prevent stack overflows + * of the graph and processes them iteratively to prevent stack overflows */ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { public: diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 912879dc8a4b0..e740776d6d4f4 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -437,13 +437,24 @@ TVM_DLL Pass RelayToTIRTargetHook(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param target_host The target used by the host for compliation. - * \param targets The device type and target pairs for compliation. + * \param target_host The target used by the host for compilation. + * \param targets The device type and target pairs for compilation. * * \return The pass. */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + * every Relay sub-expression should run (and the result stored). Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. 
See + * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * for help recovering the device for an arbitrary sub-expression in downstream transformations. + * + * \param default_device_type DLDeviceType for default device. + */ +TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); + } // namespace transform /*! diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 8830653da88cc..26f4e545deb75 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase { }; /*! - * \brief Array, container representing a contigious sequence of ObjectRefs. + * \brief Array, container representing a contiguous sequence of ObjectRefs. * * Array implements in-place copy-on-write semantics. * diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h index 3fe4f697bb9ec..977dbfbaaaa18 100644 --- a/include/tvm/runtime/container/map.h +++ b/include/tvm/runtime/container/map.h @@ -33,6 +33,7 @@ #include #include "./base.h" +#include "./optional.h" namespace tvm { namespace runtime { @@ -1344,7 +1345,14 @@ class Map : public ObjectRef { iterator end() const { return iterator(GetMapNode()->end()); } /*! \return find the key and returns the associated iterator */ iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); } - + /*! \return The value associated with the key, NullOpt if not found */ + Optional Get(const K& key) const { + MapNode::iterator iter = GetMapNode()->find(key); + if (iter == GetMapNode()->end()) { + return NullOptType{}; + } + return DowncastNoCheck(iter->second); + } void erase(const K& key) { CopyOnWrite()->erase(key); } /*! 
diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h index 664d19818be12..bb9e7ff65adc2 100644 --- a/include/tvm/runtime/container/string.h +++ b/include/tvm/runtime/container/string.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include @@ -149,6 +150,12 @@ class String : public ObjectRef { String(const char* other) // NOLINT(*) : String(std::string(other)) {} + /*! + * \brief Construct a new null object + */ + String(std::nullptr_t) // NOLINT(*) + : ObjectRef(nullptr) {} + /*! * \brief Change the value the reference object points to. * diff --git a/include/tvm/runtime/logging.h b/include/tvm/runtime/logging.h index 704c2d94b8bbd..a951264b97064 100644 --- a/include/tvm/runtime/logging.h +++ b/include/tvm/runtime/logging.h @@ -30,6 +30,7 @@ #define TVM_RUNTIME_LOGGING_H_ #include +#include #include #include @@ -38,6 +39,8 @@ #include #include #include +#include +#include /*! * \brief Macro helper to force a function not to be inlined. @@ -129,15 +132,16 @@ * a = ... * b = ... * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default - * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" + * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' + * // (default behaviour) + * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" * ... * for (int i = 0; i < N; i++) { * a = ... * b = ... * // if quit_on_assertion is true, if a==b, continue, otherwise quit. 
- * // if quit_on_assertion is false, if a==b, continue, otherwise 'break' (non-default - * // behaviour, therefore, has to be explicitly specified) + * // if quit_on_assertion is false, if a==b, continue, otherwise 'break' + * // (non-default behaviour, therefore, has to be explicitly specified) * COND_CHECK_EQ(quit_on_assertion, a, b, break) << "some error message when quiting" * } * } @@ -391,24 +395,131 @@ class LogMessageVoidify { void operator&(std::ostream&) {} }; +/*! \brief Captures the state of the \p TVM_LOG_DEBUG environment flag. */ +class TvmLogDebugSettings { + public: + /*! + * \brief Parses the \p TVM_LOG_DEBUG environment flag as per the specification given by + * \p DebugLoggingEnabled and \p VerboseLoggingEnabled, and caches the result. + */ + inline static const TvmLogDebugSettings& FromFlag() { + // Parse and cache the verbosity level map. + static const auto* settings = + new TvmLogDebugSettings(TvmLogDebugSettings::ParseSpec(std::getenv("TVM_LOG_DEBUG"))); + return *settings; + } + + /*! + * \brief Parses \p opt_spec as per specification for \p TVM_LOG_DEBUG given by + * \p DebugLoggingEnabled and \p VerboseLoggingEnabled. Throws if specification is ill-formed. + */ + static TvmLogDebugSettings ParseSpec(const char* opt_spec); + + /*! + * \brief Implements \p VerboseLoggingEnabled below w.r.t. the already parsed \p TVM_LOG_DEBUG + * environment variable. + */ + inline bool VerboseEnabled(const char* opt_filename, int level) const { + if (opt_filename == nullptr || level < 0 || vlog_level_map_.empty()) { + return false; + } + return VerboseEnabledImpl(opt_filename, level); + } + + /*! \brief Returns true if \p DLOG statements should be executed. */ + bool dlog_enabled() const { return dlog_enabled_; } + + private: + // Slow path for VerboseEnabled. + bool VerboseEnabledImpl(const std::string& filename, int level) const; + + /*! \brief If true, DLOG statements are enabled. */ + bool dlog_enabled_ = false; + /*! 
+ * \brief A map from canonicalized filenames to the maximum VLOG verbosity level for that file. + * May also contain the 'wildcard' entry \p "DEFAULT" representing the level for all other files. + */ + std::unordered_map vlog_level_map_; +}; + +/*! + * \brief Returns true if a DLOG statement is enabled by the \p TVM_LOG_DEBUG environment + * variable. Requires: + * \code + * TVM_LOG_DEBUG=1 + * \endcode + * or a valid setting as described by \p VerboseLoggingEnabled below. + */ // Also from dmlc-core inline bool DebugLoggingEnabled() { static int state = 0; if (state == 0) { - if (auto var = std::getenv("TVM_LOG_DEBUG")) { - if (std::string(var) == "1") { - state = 1; - } else { - state = -1; - } - } else { - // by default hide debug logging. - state = -1; - } + state = TvmLogDebugSettings::FromFlag().dlog_enabled() ? 1 : -1; } return state == 1; } +/*! + * \brief Returns true if a VLOG statement in \p filename is enabled by the \p TVM_LOG_DEBUG + * environment variable for logging at verbosity \p level. Levels should be non-negative. + * + * Filenames are canonicalized to be w.r.t. the src/ dir of the TVM tree. (VLOG's should not + * appear under include/). + * + * To enable file \p relay/foo.cc up to level 2 and \p ir/bar.cc for level 0 only set: + * \code + * TVM_LOG_DEBUG="relay/foo.cc=2;ir/bar.cc=0" + * \endcode + * + * To enable all files up to level 3 but disable \p ir/bar.cc set: + * \code + * TVM_LOG_DEBUG="DEFAULT=2;ir/bar.cc=-1" + * \endcode + * + * Any of these settings will also enable DLOG statements. + */ +inline bool VerboseLoggingEnabled(const char* opt_filename, int level) { + return TvmLogDebugSettings::FromFlag().VerboseEnabled(opt_filename, level); +} + +/*! + * \brief A stack of VLOG context messages. + * + * For use by \p VLOG_CONTEXT macro only. 
+ */ +class VLogContext { + public: + void Push(std::stringstream* stream) { context_stack_.push_back(stream); } + void Pop() { + if (!context_stack_.empty()) { + context_stack_.pop_back(); + } + } + + std::string str() const; + + private: + std::vector context_stack_; +}; + +/*! \brief Thread local \p VLogContext for tracking a stack of VLOG context messages. */ +using ThreadLocalVLogContext = dmlc::ThreadLocalStore; + +/*! + * \brief A RAII class to push/pos a VLOG context message onto the thread-local stack. + * + * For use by \p VLOG_CONTEXT macro only. + */ +class VLogContextEntry { + public: + VLogContextEntry() { ThreadLocalVLogContext::Get()->Push(&sstream_); } + ~VLogContextEntry() { ThreadLocalVLogContext::Get()->Pop(); } + std::ostream& stream() { return sstream_; } + + private: + std::stringstream sstream_; +}; + constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE = "\n" "---------------------------------------------------------------\n" @@ -447,6 +558,7 @@ TVM_CHECK_FUNC(_GE, >=) TVM_CHECK_FUNC(_EQ, ==) TVM_CHECK_FUNC(_NE, !=) #pragma GCC diagnostic pop + } // namespace detail #define LOG(level) LOG_##level @@ -487,6 +599,19 @@ TVM_CHECK_FUNC(_NE, !=) #define DLOG_IF(severity, condition) \ LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition)) +/*! + * \brief If the \p TVM_LOG_DEBUG build flag is enabled, push a context message onto an internal + * stack. All VLOG messages will include this stack in their prefix to help with debugging. E.g.: + * \code + * VLOG_CONTEXT << "my context"; + * VLOG(1) << "my log message"; + * \endcode + * Thread safe. No-op with no execution overhead if the \p TVM_LOG_DEBUG build flag is not enabled. + */ +#define VLOG_CONTEXT \ + ::tvm::runtime::detail::VLogContextEntry vlog_entry_; \ + vlog_entry_.stream() + #else #define LOG_DFATAL LOG_ERROR @@ -494,10 +619,33 @@ TVM_CHECK_FUNC(_NE, !=) #define DLOG(severity) true ? 
(void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) #define DLOG_IF(severity, condition) \ (true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity) +#define VLOG_CONTEXT true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO) #endif +/*! + * \brief If the \p TVM_LOG_DEBUG build flag is enabled, and the containing file has been enabled + * at \p level or greater in the \p TVM_LOG_DEBUG environment variable, then log a message at + * \p INFO severity. + * + * See \p VerboseLoggingEnabled for the format of the \p TVM_LOG_DEBUG environment variable. + * Thread safe. No-op with no execution overhead if the \p TVM_LOG_DEBUG build flag is not enabled. + * No-op with some execution overhead if the \p TVM_LOG_DEBUG build flag is enabled but the + * containing file is not enabled. + */ +#define VLOG(level) \ + DLOG_IF(INFO, ::tvm::runtime::detail::VerboseLoggingEnabled(__FILE__, (level))) \ + << ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str() + #if TVM_LOG_DEBUG +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#else #define DCHECK(x) \ while (false) CHECK(x) #define DCHECK_LT(x, y) \ @@ -512,14 +660,6 @@ TVM_CHECK_FUNC(_NE, !=) while (false) CHECK((x) == (y)) #define DCHECK_NE(x, y) \ while (false) CHECK((x) != (y)) -#else -#define DCHECK(x) CHECK(x) -#define DCHECK_LT(x, y) CHECK((x) < (y)) -#define DCHECK_GT(x, y) CHECK((x) > (y)) -#define DCHECK_LE(x, y) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) CHECK((x) == (y)) -#define DCHECK_NE(x, y) CHECK((x) != (y)) #endif #define TVM_ICHECK_INDENT " " @@ -552,5 +692,6 @@ TVM_CHECK_FUNC(_NE, !=) // Re-export error types using runtime::Error; using runtime::InternalError; + 
} // namespace tvm #endif // TVM_RUNTIME_LOGGING_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 1127a9ae732cd..a4c285e3dd086 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -38,9 +38,19 @@ #include namespace tvm { -namespace runtime { -typedef DLDevice Device; +// alias DLDevice +using Device = DLDevice; + +// A 'null' device type, does not correspond to any DLDeviceType enum. +// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case +// as a singleton target map indexed by the invalid DLDeviceType '0'. +constexpr DLDeviceType kNullDeviceType = static_cast(0); + +// An 'invalid' device type, does not correspond to any DLDeviceType enum. +constexpr DLDeviceType kInvalidDeviceType = static_cast(-1); + +namespace runtime { /*! * \brief Managed NDArray. @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) { } } // namespace runtime - -// alias Device -using tvm::runtime::Device; - } // namespace tvm namespace std { template <> -struct hash { - std::size_t operator()(const tvm::runtime::Device& dev) const { +struct hash { + std::size_t operator()(const tvm::Device& dev) const { return ((dev.device_id << 8) | dev.device_type); } }; template <> -struct equal_to { - bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const { +struct equal_to { + bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const { return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id); } }; diff --git a/include/tvm/support/parallel_for.h b/include/tvm/support/parallel_for.h index 49a9d4889e337..8bd2e6b825abc 100644 --- a/include/tvm/support/parallel_for.h +++ b/include/tvm/support/parallel_for.h @@ -57,7 +57,7 @@ TVM_DLL std::vector> rr_partitioner(int begin, int end, int ste * }); * \param begin The start index of this parallel loop(inclusive). * \param end The end index of this parallel loop(exclusive). 
- * \param f The task function to be excuted. Assert to take an int index as input with no output. + * \param f The task function to be executed. Assert to take an int index as input with no output. * \param step The traversal step to the index. * \param partitioner A partition function to split tasks to different threads. Use Round-robin * partitioner by default. @@ -67,6 +67,26 @@ TVM_DLL std::vector> rr_partitioner(int begin, int end, int ste TVM_DLL void parallel_for(int begin, int end, const std::function& f, int step = 1, const PartitionerFuncType partitioner = rr_partitioner); +/*! + * \brief An API to launch fix amount of threads to run the specific functor in parallel. + * Different from `parallel_for`, the partition is determined dynamically on the fly, + * i.e. any time when a thread is idle, it fetches the next task to run. + * The behavior is similar to dynamic scheduling in OpenMP: + * + * \#pragma omp parallel for schedule(dynamic) num_threads(num_threads) + * for (int i = 0; i < 10; i++) { + * a[i] = i; + * } + * + * \param begin The start index of this parallel loop (inclusive). + * \param end The end index of this parallel loop (exclusive). + * \param num_threads The number of threads to be used. + * \param f The task function to be executed. Takes the thread index and the task index as + * input with no output. + * \note `step` support is left for future work. + */ +TVM_DLL void parallel_for_dynamic(int begin, int end, int num_threads, + const std::function& f); } // namespace support } // namespace tvm diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 6b733d074f6a3..fcd2326050edf 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -102,6 +102,16 @@ class LinearCongruentialEngine { *rand_state_ptr_ = rand_state; // Change pointed random state to given random state value. } + /*! + * \brief Fork a new seed for another RNG from current random state. 
+ * \return The forked seed. + */ + TRandState ForkSeed() { + // In order for reproducibility, we computer the new seed using RNG's random state and a + // different set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return ((*this)() * 32767) % 1999999973; + } + /*! * \brief Construct a random number generator with a random state pointer. * \param rand_state_ptr The random state pointer given in result_type*. diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 8ea48dd592d5c..f6741112f269b 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -947,6 +947,7 @@ class ShuffleNode : public PrimExprNode { Array indices; void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &dtype); v->Visit("vectors", &vectors); v->Visit("indices", &indices); v->Visit("span", &span); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 66dd5375eaf9a..9f48d9ab9b1f5 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -216,6 +216,7 @@ class ScheduleNode : public runtime::Object { * 1) The loops can't have annotations or thread bindings. * 2) The (i+1)-th loop must be the only child of the i-th loop. * 3) All loops must start with 0. + * 4) The domain of a loop to be fused cannot depend on another loop to be fused. * \param loop_rvs The loops to be fused * \return The new loop after fusion */ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0da8e55be0233..2ae2877b2f92d 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1339,6 +1339,12 @@ constexpr const char* hand_threaded = "hand_threaded"; * if (mask & 2) the write region should be detected. */ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access"; + +/*! + * \brief Mark that the loop should be partitioned. + */ +constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; + /*! 
* \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 7470ccc92496a..d6dd094f6a5b7 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -115,6 +115,10 @@ ], ), ), + ( + "importer-paddle", + ("Requirements for the PaddlePaddle importer", ["paddlepaddle"]), + ), ( "importer-pytorch", ( @@ -235,6 +239,7 @@ ("onnx", None), ("onnxruntime", None), ("opencv-python", None), + ("paddlepaddle", None), ("pillow", None), ("progressbar", None), ("psutil", None), diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 450a356aebdf1..297e24d7f7b6d 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -72,16 +72,52 @@ class DataType(ctypes.Structure): DataTypeCode.HANDLE: "handle", DataTypeCode.BFLOAT: "bfloat", } + NUMPY2STR = { + np.dtype(np.bool_): "bool", + np.dtype(np.int8): "int8", + np.dtype(np.int16): "int16", + np.dtype(np.int32): "int32", + np.dtype(np.int64): "int64", + np.dtype(np.uint8): "uint8", + np.dtype(np.uint16): "uint16", + np.dtype(np.uint32): "uint32", + np.dtype(np.uint64): "uint64", + np.dtype(np.float16): "float16", + np.dtype(np.float32): "float32", + np.dtype(np.float64): "float64", + np.dtype(np.float_): "float64", + } + STR2DTYPE = { + "bool": {"type_code": DataTypeCode.UINT, "bits": 1, "lanes": 1}, + "int8": {"type_code": DataTypeCode.INT, "bits": 8, "lanes": 1}, + "int16": {"type_code": DataTypeCode.INT, "bits": 16, "lanes": 1}, + "int32": {"type_code": DataTypeCode.INT, "bits": 32, "lanes": 1}, + "int64": {"type_code": DataTypeCode.INT, "bits": 64, "lanes": 1}, + "uint8": {"type_code": DataTypeCode.UINT, "bits": 8, "lanes": 1}, + "uint16": {"type_code": DataTypeCode.UINT, "bits": 16, "lanes": 1}, + "uint32": {"type_code": DataTypeCode.UINT, "bits": 32, "lanes": 1}, + "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, + 
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, + "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, + "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, + } def __init__(self, type_str): super(DataType, self).__init__() - if isinstance(type_str, np.dtype): + numpy_str_map = DataType.NUMPY2STR + if type_str in numpy_str_map: + type_str = numpy_str_map[type_str] + elif isinstance(type_str, np.dtype): type_str = str(type_str) - if type_str == "bool": - self.bits = 1 - self.type_code = DataTypeCode.UINT - self.lanes = 1 + assert isinstance(type_str, str) + + str_dtype_map = DataType.STR2DTYPE + if type_str in str_dtype_map: + dtype_map = str_dtype_map[type_str] + self.bits = dtype_map["bits"] + self.type_code = dtype_map["type_code"] + self.lanes = dtype_map["lanes"] return arr = type_str.split("x") diff --git a/python/tvm/auto_scheduler/measure.py b/python/tvm/auto_scheduler/measure.py index c58aeea57d14b..8c6fd5f1a9492 100644 --- a/python/tvm/auto_scheduler/measure.py +++ b/python/tvm/auto_scheduler/measure.py @@ -909,6 +909,7 @@ def _timed_eval_func( random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True) assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake" assert len(args) == len(build_res.args) + loc_args = [] # pylint: disable=consider-using-enumerate for idx in range(len(args)): if args[idx] is None: @@ -917,11 +918,11 @@ def _timed_eval_func( get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev ) random_fill(empty_array) - args[idx] = empty_array + loc_args.append(empty_array) else: - args[idx] = ndarray.array(args[idx], dev) + loc_args.append(ndarray.array(args[idx], dev)) dev.sync() - costs = time_f(*args).results + costs = time_f(*loc_args).results # pylint: disable=broad-except except Exception: costs = (MAX_FLOAT,) @@ -1112,6 +1113,7 @@ def _rpc_run( ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices" assert len(args) 
== len(build_res.args) + loc_args = [] # pylint: disable=consider-using-enumerate for idx in range(len(args)): if args[idx] is None: @@ -1120,16 +1122,16 @@ def _rpc_run( get_const_tuple(build_res_arg.shape), build_res_arg.dtype, dev ) random_fill(empty_array) - args[idx] = empty_array + loc_args.append(empty_array) else: - args[idx] = ndarray.array(args[idx], dev) + loc_args.append(ndarray.array(args[idx], dev)) dev.sync() # First run for check that the kernel is correct - func.entry_func(*args) + func.entry_func(*loc_args) dev.sync() - costs = time_f(*args).results + costs = time_f(*loc_args).results # clean up remote files remote.remove(build_res.filename) diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 1977d3de5506e..0eacd1a1f6677 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -25,7 +25,6 @@ import json import logging import threading -from copy import deepcopy import tvm from tvm import autotvm, transform @@ -50,7 +49,6 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): """Call all TOPI compute to extract auto_scheduler tasks in a Relay program""" # pylint: disable=import-outside-toplevel from tvm import relay - from tvm.relay.backend import graph_executor_codegen # Turn off AutoTVM config not found warnings old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent @@ -64,28 +62,11 @@ def call_all_topi_funcs(mod, params, target, opt_level=3): }, disabled_pass={"AutoSchedulerLayoutRewrite"}, ): - try: - # TODO(jwfromm) Remove this once AlterOpLayout bug that mutates - # source module is fixed. Until then, create a clone. - mod_clone = deepcopy(mod) - opt_mod, _ = relay.optimize(mod_clone, target, params) - grc = graph_executor_codegen.GraphExecutorCodegen(None, target) - grc.codegen(opt_mod["main"]) - except tvm.TVMError: - print( - "Get errors with GraphExecutorCodegen for task extraction. " - "Fallback to VMCompiler." 
- ) - mod_clone = deepcopy(mod) - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - mod_clone = ( - tvm.IRModule.from_expr(mod_clone) - if isinstance(mod_clone, relay.Function) - else mod_clone - ) - compiler.lower(mod_clone, target) + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + mod = tvm.IRModule.from_expr(mod) if isinstance(mod, relay.Function) else mod + compiler.lower(mod, target) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 3dceac1b7ffde..714dd540d3ab6 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -22,7 +22,6 @@ """ import threading import logging -from copy import deepcopy import tvm from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext @@ -49,27 +48,10 @@ def _lower(mod, target, params): grc.codegen(mod["main"]) return - # default case - # Try graph codegen first to extract autotvm tasks. - # If failed to compile, then fallback to use VM compiler. - # TODO: Currently VM compiler is likely to stack overflow for large models. - try: - # TODO(jwfromm) Remove this once AlterOpLayout bug that mutates - # source module is fixed. Until then, create a clone. - mod_clone = deepcopy(mod) - opt_mod, _ = relay.optimize(mod_clone, target, params) - grc = graph_executor_codegen.GraphExecutorCodegen(None, target) - grc.codegen(opt_mod["main"]) - except tvm.TVMError as e: - print( - "Get errors with GraphExecutorCodegen for task extraction. " - "Fallback to VMCompiler. 
Error details:\n%s" % str(e) - ) - mod_clone = deepcopy(mod) - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod_clone, target=target) + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target=target) def extract_from_program(mod, params, target, target_host=None, ops=None): diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index c4b76ad04b715..c44a2fe4a1360 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -45,29 +45,6 @@ def xcrun(cmd): return out.strip() -def codesign(lib): - """Codesign the shared libary - - This is an required step for library to be loaded in - the app. - - Parameters - ---------- - lib : The path to the library. - """ - if "TVM_IOS_CODESIGN" not in os.environ: - raise RuntimeError("Require environment variable TVM_IOS_CODESIGN " " to be the signature") - signature = os.environ["TVM_IOS_CODESIGN"] - cmd = ["codesign", "--force", "--sign", signature] - cmd += [lib] - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "Codesign error:\n" - msg += py_str(out) - raise RuntimeError(msg) - - def create_dylib(output, objects, arch, sdk="macosx"): """Create dynamic library. 
@@ -181,95 +158,3 @@ def compile_coreml(model, model_name="main", out_dir="."): raise RuntimeError("Compile failed: %s" % res) return mlmodelc_path - - -class XCodeRPCServer(object): - """Wrapper for RPC server - - Parameters - ---------- - cmd : list of str - The command to run - - lock: FileLock - Lock on the path - """ - - def __init__(self, cmd, lock): - self.proc = subprocess.Popen(cmd) - self.lock = lock - - def join(self): - """Wait server to finish and release its resource""" - self.proc.wait() - self.lock.release() - - -def popen_test_rpc(host, port, key, destination, libs=None, options=None): - """Launch rpc server via xcodebuild test through another process. - - Parameters - ---------- - host : str - The address of RPC proxy host. - - port : int - The port of RPC proxy host - - key : str - The key of the RPC server - - destination : str - Destination device of deployment, as in xcodebuild - - libs : list of str - List of files to be packed into app/Frameworks/tvm - These can be dylibs that can be loaed remoted by RPC. - - options : list of str - Additional options to xcodebuild - - Returns - ------- - proc : Popen - The test rpc server process. - Don't do wait() on proc, since it can terminate normally. 
- """ - if "TVM_IOS_RPC_ROOT" in os.environ: - rpc_root = os.environ["TVM_IOS_RPC_ROOT"] - else: - curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - rpc_root = os.path.join(curr_path, "../../../apps/ios_rpc") - proj_path = os.path.realpath(os.path.join(rpc_root, "tvmrpc.xcodeproj")) - team_id = os.environ["TVM_IOS_TEAM_ID"] - if not os.path.exists(proj_path): - raise RuntimeError( - "Cannot find tvmrpc.xcodeproj in %s," - + (" please set env TVM_IOS_RPC_ROOT correctly" % rpc_root) - ) - - # Lock the path so only one file can run - lock = utils.filelock(os.path.join(rpc_root, "ios_rpc.lock")) - - with open(os.path.join(rpc_root, "rpc_config.txt"), "w") as fo: - fo.write("%s %d %s\n" % (host, port, key)) - libs = libs if libs else [] - for file_name in libs: - fo.write("%s\n" % file_name) - - cmd = [ - "xcrun", - "xcodebuild", - "-scheme", - "tvmrpc", - "-project", - proj_path, - "-destination", - destination, - ] - if options: - cmd += options - cmd += ["test"] - cmd += [f"DEVELOPMENT_TEAM={team_id}"] - - return XCodeRPCServer(cmd, lock) diff --git a/python/tvm/driver/tvmc/common.py b/python/tvm/driver/tvmc/common.py index 15c09753d46f2..9ef2f6f1fbfac 100644 --- a/python/tvm/driver/tvmc/common.py +++ b/python/tvm/driver/tvmc/common.py @@ -387,7 +387,8 @@ def parse_shape_string(inputs_string): ---------- inputs_string: str A string of the form "input_name:[dim1,dim2,...,dimn] input_name2:[dim1,dim2]" that - indicates the desired shape for specific model inputs. + indicates the desired shape for specific model inputs. Colons and forward slashes + within input_names are supported. Spaces are supported inside of dimension arrays. Returns ------- @@ -396,7 +397,11 @@ def parse_shape_string(inputs_string): """ # Create a regex pattern that extracts each separate input mapping. 
- pattern = r"(?:\w+\/)?\w+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" + # We want to be able to handle: + # * Spaces inside arrays + # * forward slashes inside names (but not at the beginning or end) + # * colons inside names (but not at the beginning or end) + pattern = r"(?:\w+\/)?[:\w]+\:\s*\[\-?\d+(?:\,\s*\-?\d+)*\]" input_mappings = re.findall(pattern, inputs_string) if not input_mappings: raise argparse.ArgumentTypeError( @@ -408,7 +413,7 @@ def parse_shape_string(inputs_string): # Remove whitespace. mapping = mapping.replace(" ", "") # Split mapping into name and shape. - name, shape_string = mapping.split(":") + name, shape_string = mapping.rsplit(":", 1) # Convert shape string into a list of integers or Anys if negative. shape = [int(x) if int(x) > 0 else relay.Any() for x in shape_string.strip("][").split(",")] # Add parsed mapping to shape dictionary. diff --git a/python/tvm/driver/tvmc/composite_target.py b/python/tvm/driver/tvmc/composite_target.py index b1f00b7d1ddec..ba7862378557c 100644 --- a/python/tvm/driver/tvmc/composite_target.py +++ b/python/tvm/driver/tvmc/composite_target.py @@ -25,6 +25,7 @@ from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib from tvm.relay.op.contrib.ethosn import partition_for_ethosn from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn +from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.op.contrib.bnns import partition_for_bnns from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai @@ -58,6 +59,10 @@ "config_key": "relay.ext.ethos-n.options", "pass_pipeline": partition_for_ethosn, }, + "ethos-u": { + "config_key": "relay.ext.ethosu.options", + "pass_pipeline": partition_for_ethosu, + }, "bnns": { "config_key": None, "pass_pipeline": partition_for_bnns, diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index b12194e7e0092..2e280ef20ac37 100644 --- a/python/tvm/meta_schedule/__init__.py +++ 
b/python/tvm/meta_schedule/__init__.py @@ -15,4 +15,10 @@ # specific language governing permissions and limitations # under the License. """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" +from . import arg_info +from . import database from . import builder +from . import runner +from . import space_generator +from . import search_strategy +from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py new file mode 100644 index 0000000000000..a56ca86e8cb79 --- /dev/null +++ b/python/tvm/meta_schedule/arg_info.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The argument information""" +from typing import Any, List, Union + +from tvm._ffi import register_object +from tvm.runtime import DataType, Object, ShapeTuple +from tvm.tir import PrimFunc + +from . 
import _ffi_api +from .utils import _json_de_tvm + + +@register_object("meta_schedule.ArgInfo") +class ArgInfo(Object): + """Argument information""" + + def as_json(self) -> Any: + """Converts the ArgInfo to its corresponding JSON representation.""" + return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any) -> "ArgInfo": + """Parse the argument information from a JSON object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + + Returns + ------- + parsed : ArgInfo + The argument information parsed. + """ + return _ffi_api.ArgInfoFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_prim_func(func: PrimFunc) -> List["ArgInfo"]: + """Extract a list of the argument information from PrimFunc. + + Parameters + ---------- + func : PrimFunc + The PrimFunc to get argument information from. + + Returns + ------- + extracted : List[ArgInfo] + An array of the argument information derived. + """ + return _ffi_api.ArgInfoFromPrimFunc(func) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TensorInfo") +class TensorInfo(ArgInfo): + """Tensor argument information + + Parameters + ---------- + dtype : DataType + The data type of the tensor. + shape : ShapeTuple + The shape of the tensor. + """ + + dtype: DataType + shape: ShapeTuple + + def __init__( + self, + dtype: DataType, + shape: Union[ShapeTuple, List[int]], + ) -> None: + """Constructor + + Parameters + ---------- + dtype : DataType + The data type of the tensor. + shape : ShapeTuple + The shape of the tensor. 
+ """ + if isinstance(shape, ShapeTuple): + shape_tuple = shape + else: + shape_tuple = ShapeTuple(shape) + self.__init_handle_by_constructor__( + _ffi_api.TensorInfo, # type: ignore # pylint: disable=no-member + dtype, + shape_tuple, + ) diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index cefe5ec50cad6..99dfaea56090a 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -48,11 +48,20 @@ class LocalBuilder(PyBuilder): Attributes ---------- T_BUILD : typing._GenericAlias - The signature of the build function `f_build`, which is - `Callable[[IRModule, Target], Module]` + The signature of the function `f_build`, which is + + .. code-block:: python + + def default_build(mod: IRModule, target: Target) -> Module: + ... + T_EXPORT : typing._GenericAlias - The signature of the build function `f_export`, which is - `Callable[[Module], str]` + The signature of the function `f_export`, which is + + .. code-block:: python + + def default_export(mod: Module) -> str: + ... Note ---- diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py new file mode 100644 index 0000000000000..dcd430d39407b --- /dev/null +++ b/python/tvm/meta_schedule/database/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.database package. +The database that stores serialized tuning records and workloads +""" +from .database import Database, PyDatabase, TuningRecord +from .json_database import JSONDatabase diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py new file mode 100644 index 0000000000000..3d05441fe22be --- /dev/null +++ b/python/tvm/meta_schedule/database/database.py @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tuning record database""" +from typing import Any, List + +from tvm._ffi import register_object +from tvm.ir.module import IRModule +from tvm.runtime import Object +from tvm.target import Target +from tvm.tir.schedule import Trace + +from .. 
import _ffi_api +from ..arg_info import ArgInfo +from ..utils import _json_de_tvm + + +@register_object("meta_schedule.Workload") +class Workload(Object): + """A workload, i.e. an IRModule and its structural hash. + + Parameters + ---------- + mod : IRModule + The workload's IRModule + """ + + mod: IRModule + + def __init__(self, mod: IRModule) -> None: + self.__init_handle_by_constructor__( + _ffi_api.Workload, # type: ignore # pylint: disable=no-member + mod, + ) + + def as_json(self) -> Any: + """Export the workload to a JSON string. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.WorkloadAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any) -> "Workload": + """Create a workload from a json object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.WorkloadFromJSON(json_obj) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.TuningRecord") +class TuningRecord(Object): + """The class of tuning records. + + Parameters + ---------- + trace : tvm.ir.Trace + The trace of the tuning record. + run_secs : List[float] + The run time of the tuning record. + workload : Workload + The workload of the tuning record. + target : Target + The target of the tuning record. + args_info : List[ArgInfo] + The argument information of the tuning record. 
+ """ + + trace: Trace + run_secs: List[float] + workload: Workload + target: Target + args_info: List[ArgInfo] + + def __init__( + self, + trace: Trace, + run_secs: List[float], + workload: Workload, + target: Target, + args_info: List[ArgInfo], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.TuningRecord, # type: ignore # pylint: disable=no-member + trace, + run_secs, + workload, + target, + args_info, + ) + + def as_json(self) -> Any: + """Export the tuning record to a JSON string. + + Returns + ------- + json_str : str + The JSON string exported. + """ + return _json_de_tvm(_ffi_api.TuningRecordAsJSON(self)) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": + """Create a tuning record from a json object. + + Parameters + ---------- + json_obj : Any + The json object to parse. + workload : Workload + The workload. + + Returns + ------- + tuning_record : TuningRecord + The parsed tuning record. + """ + return _ffi_api.TuningRecordFromJSON(json_obj, workload) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.Database") +class Database(Object): + """The abstract database interface.""" + + def commit_workload(self, mod: IRModule) -> Workload: + """Commit a workload to the database if missing. + + Parameters + ---------- + mod : IRModule + The IRModule to be searched for or added. + + Returns + ------- + workload : Workload + The workload corresponding to the given IRModule. + """ + return _ffi_api.DatabaseCommitWorkload(self, mod) # type: ignore # pylint: disable=no-member + + def commit_tuning_record(self, record: TuningRecord) -> None: + """Commit a tuning record to the database. + + Parameters + ---------- + record : TuningRecord + The tuning record to add. 
+ """ + _ffi_api.DatabaseCommitTuningRecord(self, record) # type: ignore # pylint: disable=no-member + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + """Get the top K tuning records of given workload from the database. + + Parameters + ---------- + workload : Workload + The workload to be searched for. + top_k : int + The number of top records to get. + + Returns + ------- + top_k_records : List[TuningRecord] + The top K records. + """ + return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member + + def __len__(self) -> int: + """Get the number of records in the database. + + Returns + ------- + num_records : int + The number of records in the database + """ + return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyDatabase") +class PyDatabase(Database): + """An abstract Database with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_commit_workload(mod: IRModule) -> Workload: + return self.commit_workload(mod) + + def f_commit_tuning_record(record: TuningRecord) -> None: + self.commit_tuning_record(record) + + def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]: + return self.get_top_k(workload, top_k) + + def f_size() -> int: + return len(self) + + self.__init_handle_by_constructor__( + _ffi_api.DatabasePyDatabase, # type: ignore # pylint: disable=no-member + f_commit_workload, + f_commit_tuning_record, + f_get_top_k, + f_size, + ) + + def commit_workload(self, mod: IRModule) -> Workload: + raise NotImplementedError + + def commit_tuning_record(self, record: TuningRecord) -> None: + raise NotImplementedError + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/database/json_database.py 
b/python/tvm/meta_schedule/database/json_database.py new file mode 100644 index 0000000000000..6897b82d98888 --- /dev/null +++ b/python/tvm/meta_schedule/database/json_database.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The default database that uses a JSON File to store tuning records""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .database import Database + + +@register_object("meta_schedule.JSONDatabase") +class JSONDatabase(Database): + """The class of tuning records. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + """ + + path_workload: str + path_tuning_record: str + + def __init__( + self, + path_workload: str, + path_tuning_record: str, + allow_missing: bool = True, + ) -> None: + """Constructor. + + Parameters + ---------- + path_workload : str + The path to the workload table. + path_tuning_record : str + The path to the tuning record table. + allow_missing : bool + Whether to create new file when the given path is not found. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseJSONDatabase, # type: ignore # pylint: disable=no-member + path_workload, + path_tuning_record, + allow_missing, + ) diff --git a/python/tvm/meta_schedule/runner/__init__.py b/python/tvm/meta_schedule/runner/__init__.py new file mode 100644 index 0000000000000..47f4557e1d3a8 --- /dev/null +++ b/python/tvm/meta_schedule/runner/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.runner package. +Meta Schedule runners that runs an artifact either locally or through the RPC interface +""" +from .config import EvaluatorConfig, RPCConfig +from .rpc_runner import RPCRunner +from .runner import PyRunner, Runner, RunnerFuture, RunnerInput, RunnerResult diff --git a/python/tvm/meta_schedule/runner/config.py b/python/tvm/meta_schedule/runner/config.py new file mode 100644 index 0000000000000..712766de99c1a --- /dev/null +++ b/python/tvm/meta_schedule/runner/config.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configurations for measurements in the runner""" +import os +from threading import Thread +from typing import NamedTuple, Optional, Union + +from tvm import rpc + + +class EvaluatorConfig(NamedTuple): + """Config Details of Evaluator + + Parameters + ---------- + number: int + The number of runs. + repeat: int + The number of times to repeat in each run. + min_repeat_ms: int + Minimum repeat time in ms. if the execution latency is too short, + increase the number of runs to the given time (in ms) to reduce the measurement error. + enable_cpu_cache_flush: bool + Whether to flush the cache on CPU. + + Note + ---- + The total number of actual executions is 1+number*repeat because we would warm up 1 time before + actual run. The number of runs would be increased if run time is below min_repeat_ms. 
+ """ + + number: int = 3 + repeat: int = 1 + min_repeat_ms: int = 40 + enable_cpu_cache_flush: bool = False + + @staticmethod + def _normalized(config: Optional["EvaluatorConfig"]) -> "EvaluatorConfig": + if config is None: + return EvaluatorConfig() + config = EvaluatorConfig( + number=config.number, + repeat=config.repeat, + min_repeat_ms=config.min_repeat_ms, + enable_cpu_cache_flush=config.enable_cpu_cache_flush, + ) + return config + + +class RPCConfig(NamedTuple): + """RPC configuration + + Parameters + ---------- + tracker_host: str + Host of the RPC Tracker + tracker_port: int + Port of the RPC Tracker + tracker_key: str + Key of the Tracker + session_timeout_sec: float + Timeout of the RPC session + session_priority: int + Priority of the RPC session + """ + + tracker_host: Optional[str] = None + tracker_port: Union[None, int, str] = None + tracker_key: Optional[str] = None + session_priority: int = 1 + session_timeout_sec: int = 10 + + def _sanity_check(self) -> None: + err_str = ( + "RPCConfig.{0} is not provided. 
Please provide it explicitly," + "or set environment variable {1}" + ) + if self.tracker_host is None: + raise ValueError(err_str.format("tracker_host", "TVM_TRACKER_HOST")) + if self.tracker_port is None: + raise ValueError(err_str.format("tracker_port", "TVM_TRACKER_PORT")) + if self.tracker_key is None: + raise ValueError(err_str.format("tracker_key", "TVM_TRACKER_KEY")) + + @staticmethod + def _normalized(config: Optional["RPCConfig"]) -> "RPCConfig": + if config is None: + config = RPCConfig() + config = RPCConfig( + tracker_host=config.tracker_host or os.environ.get("TVM_TRACKER_HOST", None), + tracker_port=config.tracker_port or os.environ.get("TVM_TRACKER_PORT", None), + tracker_key=config.tracker_key or os.environ.get("TVM_TRACKER_KEY", None), + session_priority=config.session_priority, + session_timeout_sec=config.session_timeout_sec, + ) + config._sanity_check() # pylint: disable=protected-access + return config + + def connect_tracker(self) -> rpc.TrackerSession: + """Connect to the tracker + + Returns + ------- + tracker : TrackerSession + The connected tracker session + """ + tracker: Optional[rpc.TrackerSession] = None + + def _connect(): + nonlocal tracker + tracker = rpc.connect_tracker(self.tracker_host, self.tracker_port) + + t = Thread(target=_connect) + t.start() + t.join(self.session_timeout_sec) + if t.is_alive() or tracker is None: + raise ValueError( + "Unable to connect to the tracker using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}" + ) + return tracker + + def connect_server(self) -> rpc.RPCSession: + """Connect to the server + + Returns + ------- + session : RPCSession + The connected rpc session + """ + tracker = self.connect_tracker() + session: 
rpc.RPCSession = tracker.request( + key=self.tracker_key, + priority=self.session_priority, + session_timeout=self.session_timeout_sec, + ) + return session + + def count_num_servers(self, allow_missing=True) -> int: + """Count the number of servers available in the tracker + + Parameters + ---------- + allow_missing : bool + Whether to allow no server to be found. + + Returns + ------- + num_servers : int + The number of servers + """ + tracker = self.connect_tracker() + tracker_summary = tracker.summary() + result: int = 0 + for item in tracker_summary["server_info"]: + _, item_key = item["key"].split(":") + if item_key == self.tracker_key: + result += 1 + if result == 0 and not allow_missing: + raise ValueError( + "Unable to find servers with the specific key using the following configuration:\n" + f" tracker host: {self.tracker_host}\n" + f" tracker port: {self.tracker_port}\n" + f" tracker key: {self.tracker_key}\n" + f" timeout (sec): {self.session_timeout_sec}\n" + "Please check the tracker status via the following command:\n" + " python3 -m tvm.exec.query_rpc_tracker " + f"--host {self.tracker_host} --port {self.tracker_port}\n" + f'and look for key: "{self.tracker_key}"' + ) + return result diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py new file mode 100644 index 0000000000000..d20e1707fcecc --- /dev/null +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -0,0 +1,567 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""RPC Runner""" +import concurrent.futures +from contextlib import contextmanager +import itertools +import os.path as osp +from typing import Any, Callable, Dict, List, Optional, Union + +from tvm.contrib.popen_pool import PopenPoolExecutor +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module, ndarray + +from ..utils import ( + get_global_func_on_rpc_session, + get_global_func_with_default_on_worker, +) +from .config import EvaluatorConfig, RPCConfig +from .runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult + + +class RPCRunnerFuture(RunnerFuture): + """RPC based runner future + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. + """ + + future: concurrent.futures.Future + timeout_sec: float + + def __init__(self, future: concurrent.futures.Future, timeout_sec: float) -> None: + """Constructor + + Parameters + ---------- + future: concurrent.futures.Future + The concurrent function to check when the function is done and to return the result. + timeout_sec: float + The timeout in seconds. 
+ """ + super().__init__() + self.future = future + self.timeout_sec = timeout_sec + + def done(self) -> bool: + return self.future.done() + + def result(self) -> RunnerResult: + try: + run_secs: List[float] = self.future.result() + except TimeoutError as exception: + return RunnerResult( + None, + error_msg=f"RPCRunner: Timeout, killed after {self.timeout_sec} seconds", + ) + except Exception as exception: # pylint: disable=broad-except + return RunnerResult( + None, + error_msg="RPCRunner: An exception occurred\n" + str(exception), + ) + return RunnerResult(run_secs, None) + + +T_ARG_INFO_JSON_OBJ = List[Any] # pylint: disable=invalid-name +T_ARG_INFO_JSON_OBJ_LIST = List[T_ARG_INFO_JSON_OBJ] # pylint: disable=invalid-name +T_ARGUMENT = Any # pylint: disable=invalid-name +T_ARGUMENT_LIST = List[T_ARGUMENT] # pylint: disable=invalid-name + + +class RPCRunner(PyRunner): + """RPC based runner + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. TODO(@junrushao1994,@zxybazh): This is not used yet. + alloc_repeat: int + The number of times to repeat the allocation. + f_create_session: Optional[str, Callable] + The function name to create the session or the function itself. + f_upload_module: Optional[str, Callable] + The function name to upload the module or the function itself. + f_alloc_argument: Optional[str, Callable] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Optional[str, Callable] + The function name to run the evaluator or the function itself. + f_cleanup: Optional[str, Callable] + The function name to cleanup the session or the function itself. + pool: PopenPoolExecutor + The popen pool executor. + + Attributes + ---------- + T_CREATE_SESSION : typing._GenericAlias + The signature of the function `f_create_session`, which is: + + .. 
code-block:: python + + def default_create_session(rpc_config: RPCConfig) -> RPCSession: + ... + + T_UPLOAD_MODULE : typing._GenericAlias + The signature of the function `f_upload_module`, which is: + + .. code-block:: python + + def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, + ) -> Module: + ... + + T_ALLOC_ARGUMENT : typing._GenericAlias + The signature of the function `f_alloc_argument`, which is: + + .. code-block:: python + + def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, + ) -> List[T_ARGUMENT_LIST]: + ... + + T_RUN_EVALUATOR : typing._GenericAlias + The signature of the function `f_run_evaluator`, which is: + + .. code-block:: python + + def default_run_evaluator( + session: RPCSession, + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], + ) -> List[float]: + ... + + T_CLEANUP : typing._GenericAlias + The signature of the function `f_cleanup`, which is: + + .. code-block:: python + + def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], + ) -> None: + ... 
+ """ + + T_CREATE_SESSION = Callable[ + [RPCConfig], # The RPC configuration + RPCSession, # The RPC Session + ] + T_UPLOAD_MODULE = Callable[ + [ + RPCSession, # The RPC Session + str, # local path to the artifact + str, # remote path to the artifact + ], + Module, # the Module opened on the remote + ] + T_ALLOC_ARGUMENT = Callable[ + [ + RPCSession, # The RPC Session + Device, # The device on the remote + T_ARG_INFO_JSON_OBJ_LIST, # The metadata information of the arguments to be allocated + int, # The number of repeated allocations to be done + ], + List[T_ARGUMENT_LIST], # A list of argument lists + ] + T_RUN_EVALUATOR = Callable[ + [ + RPCSession, # The RPC Session + Module, # The Module opened on the remote + Device, # The device on the remote + EvaluatorConfig, # The evaluator configuration + List[T_ARGUMENT_LIST], # A list of argument lists + ], + List[float], # A list of running time + ] + T_CLEANUP = Callable[ + [ + Optional[RPCSession], # The RPC Session to be cleaned up + Optional[str], # remote path to the artifact + ], + None, + ] + + rpc_config: RPCConfig + evaluator_config: EvaluatorConfig + cooldown_sec: float + alloc_repeat: int + + f_create_session: Union[T_CREATE_SESSION, str, None] + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + f_cleanup: Union[T_CLEANUP, str, None] + + pool: PopenPoolExecutor + + def __init__( + self, + rpc_config: Optional[RPCConfig] = None, + evaluator_config: Optional[EvaluatorConfig] = None, + cooldown_sec: float = 0.0, + alloc_repeat: int = 1, + f_create_session: Union[T_CREATE_SESSION, str, None] = None, + f_upload_module: Union[T_UPLOAD_MODULE, str, None] = None, + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] = None, + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] = None, + f_cleanup: Union[T_CLEANUP, str, None] = None, + max_connections: Optional[int] = None, + initializer: 
Optional[Callable[[], None]] = None, + ) -> None: + """Constructor + + Parameters + ---------- + rpc_config: RPCConfig + The rpc configuration. + evaluator_config: EvaluatorConfig + The evaluator configuration. + cooldown_sec: float + The cooldown in seconds. + alloc_repeat: int + The number of times to random fill the allocation. + f_create_session: Union[T_CREATE_SESSION, str, None] + The function name to create the session or the function itself. + f_upload_module: Union[T_UPLOAD_MODULE, str, None] + The function name to upload the module or the function itself. + f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None] + The function name to allocate the arguments or the function itself. + f_run_evaluator: Union[T_RUN_EVALUATOR, str, None] + The function name to run the evaluator or the function itself. + f_cleanup: Union[T_CLEANUP, str, None] + The function name to cleanup the session or the function itself. + max_connections: Optional[int] + The maximum number of connections. + initializer: Optional[Callable[[], None]] + The initializer function. 
+ """ + super().__init__() + self.rpc_config = RPCConfig._normalized(rpc_config) + self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) + self.cooldown_sec = cooldown_sec + self.alloc_repeat = alloc_repeat + self.f_create_session = f_create_session + self.f_upload_module = f_upload_module + self.f_alloc_argument = f_alloc_argument + self.f_run_evaluator = f_run_evaluator + self.f_cleanup = f_cleanup + + num_servers = self.rpc_config.count_num_servers(allow_missing=False) + if max_connections is None: + max_connections = num_servers + else: + max_connections = min(max_connections, num_servers) + + self.pool = PopenPoolExecutor( + max_workers=max_connections, + timeout=rpc_config.session_timeout_sec, + initializer=initializer, + ) + self._sanity_check() + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + results: List[RunnerFuture] = [] + for runner_input in runner_inputs: + future = RPCRunnerFuture( + future=self.pool.submit( + RPCRunner._worker_func, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + self.rpc_config, + self.evaluator_config, + self.alloc_repeat, + str(runner_input.artifact_path), + str(runner_input.device_type), + tuple(arg_info.as_json() for arg_info in runner_input.args_info), + ), + timeout_sec=self.rpc_config.session_timeout_sec, + ) + results.append(future) + return results + + def _sanity_check(self) -> None: + def _check( + f_create_session, + f_upload_module, + f_alloc_argument, + f_run_evaluator, + f_cleanup, + ) -> None: + get_global_func_with_default_on_worker(name=f_create_session, default=None) + get_global_func_with_default_on_worker(name=f_upload_module, default=None) + get_global_func_with_default_on_worker(name=f_alloc_argument, default=None) + get_global_func_with_default_on_worker(name=f_run_evaluator, default=None) + get_global_func_with_default_on_worker(name=f_cleanup, default=None) + + value = self.pool.submit( + 
_check, + self.f_create_session, + self.f_upload_module, + self.f_alloc_argument, + self.f_run_evaluator, + self.f_cleanup, + ) + value.result() + + @staticmethod + def _worker_func( + _f_create_session: Union[T_CREATE_SESSION, str, None], + _f_upload_module: Union[T_UPLOAD_MODULE, str, None], + _f_alloc_argument: Union[T_ALLOC_ARGUMENT, str, None], + _f_run_evaluator: Union[T_RUN_EVALUATOR, str, None], + _f_cleanup: Union[T_CLEANUP, str, None], + rpc_config: RPCConfig, + evaluator_config: EvaluatorConfig, + alloc_repeat: int, + artifact_path: str, + device_type: str, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + ) -> List[float]: + # Step 0. Get the registered functions + f_create_session: RPCRunner.T_CREATE_SESSION = get_global_func_with_default_on_worker( + _f_create_session, default_create_session + ) + f_upload_module: RPCRunner.T_UPLOAD_MODULE = get_global_func_with_default_on_worker( + _f_upload_module, default_upload_module + ) + f_alloc_argument: RPCRunner.T_ALLOC_ARGUMENT = get_global_func_with_default_on_worker( + _f_alloc_argument, default_alloc_argument + ) + f_run_evaluator: RPCRunner.T_RUN_EVALUATOR = get_global_func_with_default_on_worker( + _f_run_evaluator, default_run_evaluator + ) + f_cleanup: RPCRunner.T_CLEANUP = get_global_func_with_default_on_worker( + _f_cleanup, default_cleanup + ) + # Managed resources + session: Optional[RPCSession] = None + remote_path: Optional[str] = None + + @contextmanager + def resource_handler(): + try: + yield + finally: + # Step 5. Clean up + f_cleanup(session, remote_path) + + with resource_handler(): + # Step 1. Create session + session = f_create_session(rpc_config) + device = session.device(dev_type=device_type, dev_id=0) + # Step 2. 
Upload the module + _, remote_path = osp.split(artifact_path) + local_path: str = artifact_path + rt_mod: Module = f_upload_module(session, local_path, remote_path) + # Step 3: Allocate input arguments + repeated_args: List[T_ARGUMENT_LIST] = f_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + # Step 4: Run time_evaluator + costs: List[float] = f_run_evaluator( + session, + rt_mod, + device, + evaluator_config, + repeated_args, + ) + return costs + + +def default_create_session(rpc_config: RPCConfig) -> RPCSession: + """Default function to create the session + + Parameters + ---------- + rpc_config : RPCConfig + The configuration of the RPC session + + Returns + ------- + session : RPCSession + The created rpc session + """ + return rpc_config.connect_server() + + +def default_upload_module( + session: RPCSession, + local_path: str, + remote_path: str, +) -> Module: + """Default function to upload the module + + Parameters + ---------- + session: RPCSession + The session to upload the module + local_path: str + The local path of the module + remote_path: str + The remote path to place the module + + Returns + ------- + rt_mod : Module + The runtime module + """ + session.upload(local_path, remote_path) + rt_mod: Module = session.load_module(remote_path) + return rt_mod + + +def default_alloc_argument( + session: RPCSession, + device: Device, + args_info: T_ARG_INFO_JSON_OBJ_LIST, + alloc_repeat: int, +) -> List[T_ARGUMENT_LIST]: + """Default function to allocate the arguments + + Parameters + ---------- + session: RPCSession + The session to allocate the arguments + device: Device + The device to allocate the arguments + alloc_repeat: int + The number of times to repeat the allocation + args_info: PyArgsInfo + The arguments info + + Returns + ------- + repeated_args: List[Args] + The allocation args + """ + f_random_fill = get_global_func_on_rpc_session( + session, + "tvm.contrib.random.random_fill", + "Please make sure 'USE_RANDOM' is turned 
ON in the config.cmake on the RPC server.", + ) + + def alloc_tensor(_, dtype, shape) -> ndarray.NDArray: + arg = ndarray.empty(shape=shape, dtype=dtype, device=device) + f_random_fill(arg) + return arg + + def alloc_fail(*arg_info) -> None: + raise NotImplementedError(arg_info) + + dispatcher: Dict[Any, Callable] = { + "TENSOR": alloc_tensor, + None: alloc_fail, + } + + repeated_args: List[T_ARGUMENT_LIST] = [] + for _ in range(alloc_repeat): + args: T_ARGUMENT_LIST = [] + arg_info: T_ARG_INFO_JSON_OBJ + for arg_info in args_info: + arg_type = arg_info[0] + arg: Any = dispatcher.get(arg_type, None)(*arg_info) + args.append(arg) + repeated_args.append(args) + return repeated_args + + +def default_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[T_ARGUMENT_LIST], +) -> List[float]: + """Default function to run the evaluator + + Parameters + ---------- + session: RPCSession + The session to run the evaluator + rt_mod: Module + The runtime module + device: Device + The device to run the evaluator + evaluator_config: EvaluatorConfig + The evaluator config + repeated_args: List[Args] + The repeated arguments + + Returns + ------- + costs: List[float] + The evaluator results + """ + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs + + +def default_cleanup( + session: Optional[RPCSession], + remote_path: Optional[str], +) -> None: + """Default 
function to clean up the session + + Parameters + ---------- + session: RPCSession + The session to clean up + remote_path: str + The remote path to clean up + """ + if session is not None and remote_path is not None: + session.remove(remote_path) + session.remove(remote_path + ".so") + session.remove("") diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py new file mode 100644 index 0000000000000..9f7be8ea4af48 --- /dev/null +++ b/python/tvm/meta_schedule/runner/runner.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Runners""" +from typing import List, Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from .. import _ffi_api +from ..arg_info import ArgInfo + + +@register_object("meta_schedule.RunnerInput") +class RunnerInput(Object): + """The runner's input + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. 
+ """ + + artifact_path: str + device_type: str + args_info: List[ArgInfo] + + def __init__( + self, + artifact_path: str, + device_type: str, + args_info: List[ArgInfo], + ) -> None: + """Constructor + + Parameters + ---------- + artifact_path : str + The path to the built artifact. + device_type : str + The device type. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerInput, # type: ignore # pylint: disable=no-member + artifact_path, + device_type, + args_info, + ) + + +@register_object("meta_schedule.RunnerResult") +class RunnerResult(Object): + """The runner's result + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. + """ + + run_secs: Optional[List[float]] + error_msg: Optional[str] + + def __init__( + self, + run_secs: Optional[List[float]], + error_msg: Optional[str], + ) -> None: + """Constructor + + Parameters + ---------- + run_secs : Optional[List[float]] + The run time in seconds. + error_msg : Optional[str] + The error message, if any. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.RunnerResult, # type: ignore # pylint: disable=no-member + run_secs, + error_msg, + ) + + +@register_object("meta_schedule.RunnerFuture") +class RunnerFuture(Object): + """A class to fetch asynchronous runner's output.""" + + def __init__(self) -> None: + """Constructor""" + + def f_done(): + return self.done() + + def f_result(): + return self.result() + + self.__init_handle_by_constructor__( + _ffi_api.RunnerFuture, # type: ignore # pylint: disable=no-member + f_done, + f_result, + ) + + def done(self) -> bool: + """Check whether the runner has finished.""" + raise NotImplementedError + + def result(self) -> RunnerResult: + """Fetch the runner's output if it is ready.""" + raise NotImplementedError + + +@register_object("meta_schedule.Runner") +class Runner(Object): + """The abstract runner interface""" + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + """Run the built artifact and get runner futures. + + Parameters + ---------- + runner_inputs : List[RunnerInput] + The inputs to the runner. + + Returns + ------- + runner_futures: List[RunnerFuture] + The runner futures. 
+ """ + return _ffi_api.RunnerRun(self, runner_inputs) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyRunner") +class PyRunner(Runner): + """An abstract runner with customized build method on the python-side.""" + + def __init__(self) -> None: + """Constructor""" + + def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + return self.run(runner_inputs) + + self.__init_handle_by_constructor__( + _ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member + f_run, + ) + + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py new file mode 100644 index 0000000000000..40f21da0b2d1c --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Search Strategy""" + +from .search_strategy import SearchStrategy, PySearchStrategy +from .replay_trace import ReplayTrace diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py new file mode 100644 index 0000000000000..3afdff6de77ee --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayTrace") +class ReplayTrace(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. 
+ """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__(self, num_trials_per_iter: int, num_trials_total: int): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.ReplayTrace, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py new file mode 100644 index 0000000000000..72713155c41d7 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Search Strategy""" + +from typing import List, Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..arg_info import ArgInfo +from ..runner import RunnerResult + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.MeasureCandidate") +class MeasureCandidate(Object): + """Measure candidate class. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. 
+ """ + + sch: Schedule + args_info: List[ArgInfo] + + def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + """Constructor. + + Parameters + ---------- + sch : Schedule + The schedule to be measured. + args_info : List[ArgInfo] + The argument information. + """ + self.__init_handle_by_constructor__( + _ffi_api.MeasureCandidate, # pylint: disable=no-member + sch, + args_info, + ) + + +@register_object("meta_schedule.SearchStrategy") +class SearchStrategy(Object): + """ + Search strategy is the class that generates the measure candidates. It has to be pre-tuned + before usage and post-tuned after usage. + """ + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the search strategy with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initialization. + """ + _ffi_api.SearchStrategyInitializeWithTuneContext( # pylint: disable=no-member + self, tune_context + ) + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + """Pre-tuning for the search strategy. + + Parameters + ---------- + design_spaces : List[Schedule] + The design spaces for pre-tuning. + """ + _ffi_api.SearchStrategyPreTuning(self, design_spaces) # pylint: disable=no-member + + def post_tuning(self) -> None: + """Post-tuning for the search strategy.""" + _ffi_api.SearchStrategyPostTuning(self) # pylint: disable=no-member + + def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: + """Generate measure candidates from design spaces for measurement. + + Returns + ------- + measure_candidates : Optional[List[IRModule]] + The measure candidates generated, None if finished. + """ + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # pylint: disable=no-member + + def notify_runner_results(self, results: List[RunnerResult]) -> None: + """Update the search strategy with profiling results. 
+ + Parameters + ---------- + results : List[RunnerResult] + The profiling results from the runner. + """ + _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # pylint: disable=no-member + + +@register_object("meta_schedule.PySearchStrategy") +class PySearchStrategy(SearchStrategy): + """An abstract search strategy with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(context: "TuneContext") -> None: + self.initialize_with_tune_context(context) + + def f_pre_tuning(design_spaces: List[Schedule]) -> None: + self.pre_tuning(design_spaces) + + def f_post_tuning() -> None: + self.post_tuning() + + def f_generate_measure_candidates() -> List[MeasureCandidate]: + return self.generate_measure_candidates() + + def f_notify_runner_results(results: List["RunnerResult"]) -> None: + self.notify_runner_results(results) + + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyPySearchStrategy, # pylint: disable=no-member + f_initialize_with_tune_context, + f_pre_tuning, + f_post_tuning, + f_generate_measure_candidates, + f_notify_runner_results, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def pre_tuning(self, design_spaces: List[Schedule]) -> None: + raise NotImplementedError + + def post_tuning(self) -> None: + raise NotImplementedError + + def generate_measure_candidates(self) -> List[MeasureCandidate]: + raise NotImplementedError + + def notify_runner_results(self, results: List["RunnerResult"]) -> None: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py new file mode 100644 index 0000000000000..af759d43b34a6 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.space_generator package. +Meta Schedule design space generators that generates design +space for generation of measure candidates. +""" + +from .space_generator import SpaceGenerator, PySpaceGenerator +from .space_generator_union import SpaceGeneratorUnion +from .schedule_fn import ScheduleFn diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py new file mode 100644 index 0000000000000..64edd9e0bf8c2 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta schedule design space generators that generates design +space via a schedule function. +""" +from typing import TYPE_CHECKING, Callable, List, Union + +from tvm.ir import IRModule +from tvm.ir.container import Array +from tvm.tir.schedule import Schedule + +from .space_generator import PySpaceGenerator + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class ScheduleFn(PySpaceGenerator): + """A design space generator with design spaces specified by a schedule function.""" + + # Multiple cases of schedule functions supported + SCH_FN_TYPE = Union[ + Callable[[IRModule], None], # No output + Callable[[IRModule], Schedule], # Single output + Callable[[IRModule], List[Schedule]], # Multiple outputs + ] + + def __init__(self, sch_fn: SCH_FN_TYPE): + """Constructor. + + Parameters + ---------- + sch_fn : SCH_FN_TYPE + The schedule function. + """ + super().__init__() + self.sch_fn = sch_fn + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the design space generator with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the design space generator. + """ + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + """Generate design spaces given a module. + + Parameters + ---------- + mod : IRModule + The module used for design space generation. + + Returns + ------- + design_spaces : List[Schedule] + The generated design spaces, i.e., schedules. + """ + sch = Schedule(mod) # Make sure the schedule is traced + result = self.sch_fn(sch) # Call the schedule function + if result is None: # Case 1. No output + return [sch] + if isinstance(result, Schedule): # Case 2. Single output + return [result] + if isinstance(result, (list, tuple, Array)): # Case 3. 
Multiple outputs + for ret in result: # enumerate the outputs + if not isinstance(ret, Schedule): + raise TypeError( + "Wrong type of element in the list, expected Schedule got " + + f"'{type(ret)}': {ret}" + ) + return result + raise TypeError(f"Unexpected return type {type(result)}: {result}") diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py new file mode 100644 index 0000000000000..798753d913456 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule design space generators that generates design +space for generation of measure candidates. +""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. 
import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.SpaceGenerator") +class SpaceGenerator(Object): + """The abstract design space generator interface.""" + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the design space generator with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the design space generator. + """ + _ffi_api.SpaceGeneratorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + """Generate design spaces given a module. + + Parameters + ---------- + mod : IRModule + The module used for design space generation. + + Returns + ------- + design_spaces : List[Schedule] + The generated design spaces, i.e., schedules. + """ + return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PySpaceGenerator") +class PySpaceGenerator(SpaceGenerator): + """An abstract design space generator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + def f_generate_design_space(mod: IRModule) -> List[Schedule]: + return self.generate_design_space(mod) + + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_generate_design_space, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py 
b/python/tvm/meta_schedule/space_generator/space_generator_union.py new file mode 100644 index 0000000000000..5541ab0b50267 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Union of meta Schedule design space generators.""" +from typing import List + +from tvm._ffi import register_object + +from .. import _ffi_api +from .space_generator import SpaceGenerator + + +@register_object("meta_schedule.SpaceGeneratorUnion") +class SpaceGeneratorUnion(SpaceGenerator): + """Union of design space generators.""" + + def __init__(self, space_generators: List[SpaceGenerator]): + """Constructor. + + Parameters + ---------- + space_generators : List[SpaceGenerator] + The list of design space generators to be unioned. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorSpaceGeneratorUnion, # type: ignore # pylint: disable=no-member + space_generators, + ) diff --git a/python/tvm/meta_schedule/testing.py b/python/tvm/meta_schedule/testing.py new file mode 100644 index 0000000000000..8fc095efd8b84 --- /dev/null +++ b/python/tvm/meta_schedule/testing.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Testing utilities in meta schedule""" +import time + +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server + + +class LocalRPC: + """A pair of RPC tracker/server running locally + + Parameters + ---------- + tracker_host : str + The host URL of the tracker + tracker_port : int + The port of the tracker + tracker_key: str + The key used in the tracker to refer to a worker + """ + + tracker_host: str + tracker_port: int + tracker_key: str + + def __init__( + self, + tracker_key: str = "key", + silent: bool = False, + no_fork: bool = False, + ) -> None: + self.tracker = Tracker( + silent=silent, + port=9190, + port_end=12345, + ) + time.sleep(0.5) + self.server = Server( + host="0.0.0.0", + is_proxy=False, + tracker_addr=(self.tracker.host, self.tracker.port), + key=tracker_key, + silent=silent, + no_fork=no_fork, + port=9190, + port_end=12345, + ) + time.sleep(0.5) + self.tracker_host = self.tracker.host + self.tracker_port = self.tracker.port + self.tracker_key = tracker_key + + def __enter__(self): + return self + + def __exit__(self, _type, _value, _traceback): + if hasattr(self, "server"): + del self.server + if hasattr(self, "tracker"): + del self.tracker diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py new file mode 100644 index 0000000000000..9c41b4d575dae --- /dev/null +++ b/python/tvm/meta_schedule/tune_context.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule tuning context.""" + +from typing import Optional, TYPE_CHECKING + +from tvm import IRModule +from tvm._ffi import register_object +from tvm.meta_schedule.utils import cpu_count +from tvm.runtime import Object +from tvm.target import Target + +from . import _ffi_api + +if TYPE_CHECKING: + from .space_generator import SpaceGenerator + + +@register_object("meta_schedule.TuneContext") +class TuneContext(Object): + """ + The tune context class is designed to contain all resources for a tuning task. + + Different tuning tasks are separated in different TuneContext classes, but different classes in + the same task can interact with each other through tune context. Most classes have a function + to initialize with a tune context. + + Parameters + ---------- + mod : Optional[IRModule] = None + The workload to be optimized. + target : Optional[Target] = None + The target to be optimized for. + task_name : Optional[str] = None + The name of the tuning task. + rand_state : int = -1 + The random state. + Need to be in integer in [1, 2^31-1], -1 means using random number. + num_threads : int = None + The number of threads to be used, None means using the logical cpu count. + + Note + ---- + In most cases, mod and target should be available in the tuning context. They are "Optional" + because we allow the user to customize the tuning context, along with other classes, sometimes + without mod and target. E.g., we can have a stand alone search strategy that generates measure + candidates without initializing with the tune context. 
+ """ + + mod: Optional[IRModule] + target: Optional[Target] + task_name: Optional[str] + rand_state: int + num_threads: int + + def __init__( + self, + mod: Optional[IRModule] = None, + target: Optional[Target] = None, + space_generator: Optional["SpaceGenerator"] = None, + task_name: Optional[str] = None, + rand_state: int = -1, + num_threads: Optional[int] = None, + ): + """Constructor. + + Parameters + ---------- + mod : Optional[IRModule] = None + The workload to be optimized. + target : Optional[Target] = None + The target to be optimized for. + space_generator : Optional[SpaceGenerator] = None + The design space generator. + task_name : Optional[str] = None + The name of the tuning task. + rand_state : int = -1 + The random state. + Need to be in integer in [1, 2^31-1], -1 means using random number. + num_threads : Optional[int] = None + The number of threads to be used, None means using the logical cpu count. + """ + if num_threads is None: + num_threads = cpu_count() + + self.__init_handle_by_constructor__( + _ffi_api.TuneContext, # type: ignore # pylint: disable=no-member + mod, + target, + space_generator, + task_name, + rand_state, + num_threads, + ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 74f93e86f506d..5f536994a9fd2 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,14 +15,18 @@ # specific language governing permissions and limitations # under the License. 
"""Utilities for meta schedule""" +import json import os import shutil -from typing import Callable, Union +from typing import Any, Callable, List, Optional, Union import psutil - from tvm._ffi import get_global_func, register_func from tvm.error import TVMError +from tvm.ir import Array, Map +from tvm.rpc import RPCSession +from tvm.runtime import PackedFunc, String +from tvm.tir import FloatImm, IntImm @register_func("meta_schedule.cpu_count") @@ -91,7 +95,91 @@ def get_global_func_with_default_on_worker( ) from error +def get_global_func_on_rpc_session( + session: RPCSession, + name: str, + extra_error_msg: Optional[str] = None, +) -> PackedFunc: + """Get a PackedFunc from the global registry from an RPCSession. + + Parameters + ---------- + session : RPCSession + The RPCSession to be retrieved from + name : str + The name of the PackedFunc + extra_error_msg : Optional[str] + Extra information to provide in the error message + + Returns + ------- + result : PackedFunc + The result + """ + try: + result = session.get_function(name) + except AttributeError as error: + error_msg = f'Unable to find function "{name}" on the remote RPC server.' + if extra_error_msg: + error_msg = f"{error_msg} {extra_error_msg}" + raise AttributeError(error_msg) from error + return result + + @register_func("meta_schedule.remove_build_dir") def remove_build_dir(artifact_path: str) -> None: """Clean up the build directory""" shutil.rmtree(os.path.dirname(artifact_path)) + + +def _json_de_tvm(obj: Any) -> Any: + """Unpack a TVM nested container to a JSON object in python. + + Parameters + ---------- + obj : Any + The TVM nested container to be unpacked. + + Returns + ------- + result : Any + The unpacked json object. 
+ """ + if obj is None: + return None + if isinstance(obj, (int, float)): + return obj + if isinstance(obj, (IntImm, FloatImm)): + return obj.value + if isinstance(obj, (str, String)): + return str(obj) + if isinstance(obj, Array): + return [_json_de_tvm(i) for i in obj] + if isinstance(obj, Map): + return {_json_de_tvm(k): _json_de_tvm(v) for k, v in obj.items()} + raise TypeError("Not supported type: " + str(type(obj))) + + +@register_func("meta_schedule.json_obj2str") +def json_obj2str(json_obj: Any) -> str: + json_obj = _json_de_tvm(json_obj) + return json.dumps(json_obj) + + +@register_func("meta_schedule.batch_json_str2obj") +def batch_json_str2obj(json_strs: List[str]) -> List[Any]: + """Covert a list of JSON strings to a list of json objects. + Parameters + ---------- + json_strs : List[str] + The list of JSON strings + Returns + ------- + result : List[Any] + The list of json objects + """ + return [ + json.loads(json_str) + for json_str in map(str.strip, json_strs) + if json_str and (not json_str.startswith("#")) and (not json_str.startswith("//")) + ] diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py index d9961e9de3f94..5a4841f39f7cb 100644 --- a/python/tvm/micro/interface_api.py +++ b/python/tvm/micro/interface_api.py @@ -57,9 +57,12 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): metadata_header = os.path.join(output_path, f"{mangled_name}.h") with open(metadata_header, "w") as header_file: header_file.write( - "#include \n" f"#ifndef {mangled_name.upper()}_H_\n" - f"#define {mangled_name.upper()}_H_\n" + f"#define {mangled_name.upper()}_H_\n\n" + "#include \n\n" + "#ifdef __cplusplus\n" + 'extern "C" {\n' + "#endif\n\n" ) _emit_brief(header_file, module_name, "Input tensor pointers") @@ -91,6 +94,8 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): ");\n" ) - header_file.write(f"#endif // {mangled_name.upper()}_H_\n") + header_file.write( + "\n#ifdef 
__cplusplus\n}\n#endif\n\n" f"#endif // {mangled_name.upper()}_H_\n" + ) return metadata_header diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index 8a62c9b5f9ba1..d1a36ac79d640 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -18,7 +18,7 @@ """Defines glue wrappers around the Project API which mate to TVM interfaces.""" import pathlib -import typing +from typing import Union from .. import __version__ from ..contrib import utils @@ -64,7 +64,7 @@ class GeneratedProject: """Defines a glue interface to interact with a generated project through the API server.""" @classmethod - def from_directory(cls, project_dir: typing.Union[pathlib.Path, str], options: dict): + def from_directory(cls, project_dir: Union[pathlib.Path, str], options: dict): return cls(client.instantiate_from_dir(project_dir), options) def __init__(self, api_client, options): @@ -101,7 +101,17 @@ def __init__(self, api_client): if not self._info["is_template"]: raise NotATemplateProjectError() + def _check_project_options(self, options: dict): + """Check if options are valid ProjectOptions""" + available_options = [option["name"] for option in self.info()["project_options"]] + if options and not set(options.keys()).issubset(available_options): + raise ValueError( + f"""options:{list(options)} include non valid ProjectOptions. 
+ Here is a list of available options:{list(available_options)}.""" + ) + def generate_project_from_mlf(self, model_library_format_path, project_dir, options): + self._check_project_options(options) self._api_client.generate_project( model_library_format_path=str(model_library_format_path), standalone_crt_dir=get_standalone_crt_dir(), @@ -124,9 +134,9 @@ def generate_project(self, graph_executor_factory, project_dir, options): def generate_project( - template_project_dir: typing.Union[pathlib.Path, str], + template_project_dir: Union[pathlib.Path, str], module: ExportableModule, - generated_project_dir: typing.Union[pathlib.Path, str], + generated_project_dir: Union[pathlib.Path, str], options: dict = None, ): """Generate a project for an embedded platform that contains the given model. @@ -154,3 +164,36 @@ def generate_project( """ template = TemplateProject.from_directory(str(template_project_dir)) return template.generate_project(module, str(generated_project_dir), options) + + +def generate_project_from_mlf( + template_project_dir: Union[pathlib.Path, str], + project_dir: Union[pathlib.Path, str], + mlf_path: Union[pathlib.Path, str], + options: dict, +): + """Generate a project from a platform template and an existing Model Library Format archive. + + Parameters + ---------- + template_project_path : pathlib.Path or str + Path to a template project containing a microTVM Project API server. + + project_dir : pathlib.Path or str + Path to a directory where the project will be created. + + mlf_path : pathlib.Path or str + Path to the Model Library Format archive that will be used when creating + the new project. + + options : dict + Project API options given to the microTVM API server for the specified platform. + + Returns + ------- + GeneratedProject : + A class that wraps the generated project and which can be used to further interact with it. 
+ """ + + template = TemplateProject.from_directory(str(template_project_dir)) + return template.generate_project_from_mlf(str(mlf_path), str(project_dir), options) diff --git a/python/tvm/micro/project_api/client.py b/python/tvm/micro/project_api/client.py index ac8ff629a7185..f1eb115cfbbed 100644 --- a/python/tvm/micro/project_api/client.py +++ b/python/tvm/micro/project_api/client.py @@ -205,7 +205,6 @@ def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bo """Launch server located in project_dir, and instantiate a Project API Client connected to it.""" args = None - project_dir = pathlib.Path(project_dir) python_script = project_dir / SERVER_PYTHON_FILENAME diff --git a/python/tvm/parser/__init__.py b/python/tvm/parser/__init__.py index 60fcddb17f08b..d75ad16ebab2c 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/parser/__init__.py @@ -26,8 +26,10 @@ def add(self, name, content): return _ffi.get_global_func("SourceMapAdd")(self, name, content) -def parse(source, source_name="from_string"): - return _ffi_api.ParseModule(source_name, source) +def parse(source, source_name="from_string", init_module=None, init_meta_table=None): + if init_meta_table is None: + init_meta_table = {} + return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table) def parse_expr(source): diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index 2b424ebb5dec0..ed04c202d8af4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -19,6 +19,6 @@ from . import legalize from . import preprocess from . import errors +from . import codegen from . import vela_api from . 
import tir_to_cs_translator -from .util import partition_for_ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py new file mode 100644 index 0000000000000..e821ea8bf0c4b --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Codegen for Arm(R) Ethos(TM)-U""" +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants +from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator +from tvm.relay.backend.contrib.ethosu import util + + +@tvm._ffi.register_func("relay.ext.ethosu.constant_updater") +def constant_updater(expr, symbol): # pylint: disable=unused-argument + """ + We dont want the build process to extract constants to be loaded in + the runtime as we are embedding them inside the C runtime.Module. 
+ """ + return dict() + + +@tvm._ffi.register_func("relay.ext.ethosu") +def ethosu_compiler(ref): + """Main function to a compile a given relay function of + NPU compatible operators to generated command stream. + Such generated command stream would be loaded to the runtime + module that interfaces with NPU driver. + """ + assert isinstance(ref, tvm.ir.function.BaseFunc) + func_name = ref.attrs["global_symbol"] + # There should only be a single input + assert len(ref.params) == 1 + input_size = util.calculate_size_bytes(ref.params[0]) + output_size = util.calculate_size_bytes(ref.body) + cmms, encoded_constants, scratch_size = _compile(ref) + ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethosu.create") + return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size, input_size, output_size) + + +def _compile(ext_func): + """ + This is the main wrapper that accepts an external + relay function and runs all the passes to lower it down + to command stream + Parameters + ---------- + ext_func : tvm.relay.function.Function + The partitioned relay function + Returns + ------- + cs : str + An hex string of the bytes of command stream + encoded_constants : str + An hex string of the bytes that includes concat'd + encoded weights, encoded biases and scales. + scratch_size : int + The size of the scratch buffer needed. + """ + mod = tvm.IRModule() + mod["main"] = ext_func + mod = LegalizeEthosU()(mod) + mod = relay.transform.InferType()(mod) + # We are currently using copy_constants scheduler In the long run, + # this should be a single intelligent and a composite scheduler + # that can perform scheduling based on user inputs such as + # scratch memory size. 
+ tir_mod, params = lower_to_tir(mod["main"], copy_constants()) + cmms, encoded_constants, scratch_size = tir_to_cs_translator.translate(tir_mod, params) + return cmms, encoded_constants, scratch_size diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 82b7f1e68ceee..fd58da803623c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -221,3 +221,9 @@ def transform_module( mod = LegalizeSplit()(mod) mod = LegalizeEthosUConv2D()(mod) return mod + + def __call__(self, *args, **kwargs): + # pylint is unable figure out the decorated + # class is callable, thus adding this to + # suppress the warning. + pass diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index ce9abcbd683d4..4b28dc5b191e9 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -25,6 +25,8 @@ import ethosu.vela.api as vapi # type: ignore import tvm +from tvm.tir import stmt_functor +from tvm.relay.backend.contrib.ethosu import util from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu.tir import spec @@ -39,6 +41,14 @@ class BufferType(Enum): output = auto() +_REGION_MAP = { + BufferType.constant: 0, + BufferType.scratch: 1, + BufferType.input: 3, + BufferType.output: 4, +} + + class BufferInfo(NamedTuple): """A data structure to hold metadata of the buffer""" @@ -49,6 +59,72 @@ class BufferInfo(NamedTuple): btype: BufferType +def translate(tir_module, params): + """This will take an tir module for the NPU + and compile to command stream + + Parameters + ---------- + tir_module : tvm.IRModule + The TIR module containing ethosu extern calls + params : dict + A dictionary containing TIR primfunc argument ordering + idx to constant NDArray 
map + accel_type : ethosu.vela.api.NpuAccelerator + the accelerator variant the tir module needs to compiled to + + Returns + ------- + cs : str + An hex string of the bytes of command stream + encoded_constants : str + An hex string of the bytes that includes concat'd + encoded weights, encoded biases and scales. + scratch_size : int + The size of the scratch buffer needed. + """ + + buffer_info = extract_buffer_info(tir_module, params) + extern_calls = extract_extern_calls(tir_module) + _npu_ops = list() + for extern_call in extern_calls: + _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) + _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) + target_accel_type = vela_api.get_target_accel_type() + cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_type) + payload = vapi.npu_create_driver_payload(cmds, target_accel_type) + hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex() + return payload.hex(), hex_value, scratch_size + + +def extract_extern_calls(mod): + """This function will obtain all extern + calls from a TIR module + Parameters + ---------- + mod : tvm.IRModule + The TIR Module for NPU + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + extern_calls = list() + + def populate_extern_calls(stmt): + if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": + extern_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) + return extern_calls + + def extract_buffer_info(mod, param_dict): """ This function is to read the tvm.IRModule that @@ -101,6 +177,156 @@ def populate_allocate_buffer_info(stmt): return buffer_info +def assign_addresses(buffer_info, npu_ops): + """This function will assign addresses to tensors + within two buffers : scratch and constants. 
+ The scratch is the buffer created to hold all intermediary data + The constants is the buffer created via unifying all the constant data + (post-encoding). + Parameters + ---------- + buffer_info : dict + This is the dictionary obtained via calling extract_buffer_info. + The key is the buffer name to BufferInfo + npu_ops : list + A list of Vela NpuOps with tir.Loads for addresses + Returns + ------- + npu_ops : list + A list of Vela NpuOps with addesses within scratch and constant buffers + constant_tensor : NDArray + A unified constant data array of uint8 as the constant buffer + scratch_size : int + The size of the scratch tensor. + """ + + def replace_npu_fm_with_address(npu_fm): + assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load) + # We currently does not support tiles + # Change this when tiles are needed + # (i.e. when using rolling buffers) + assert npu_fm.tiles.addresses[1:] == [0, 0, 0] + npu_fm.tiles.addresses[1:] = [0, 0, 0] + buffer = npu_fm.tiles.addresses[0].buffer_var + assert buffer in buffer_addresses.keys() + address, buffer_type = buffer_addresses[buffer] + npu_fm.tiles.addresses[0] = address + npu_fm.region = _REGION_MAP[buffer_type] + return npu_fm + + def replace_npu_address_range_with_address(npu_addr_range): + assert isinstance(npu_addr_range.address, tvm.tir.Load) + buffer = npu_addr_range.address.buffer_var + assert buffer in buffer_addresses.keys() + address, buffer_type = buffer_addresses[buffer] + return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) + + def replace_tir_loads(npu_object): + if isinstance(npu_object, vapi.NpuFeatureMap): + return replace_npu_fm_with_address(npu_object) + if isinstance(npu_object, vapi.NpuAddressRange): + return replace_npu_address_range_with_address(npu_object) + return npu_object + + def classify_io(buffer): + for _npu_op in npu_ops: + if issubclass(type(_npu_op), vapi.NpuBlockOperation): + if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer: 
+ return BufferType.input + if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer_var == buffer: + return BufferType.input + if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer_var == buffer: + return BufferType.output + + raise ValueError(f"Unused IO : {buffer} in tir module.") + + scratch_size = 0 + constant_tensor = None + buffer_addresses = dict() + for _buffer, info in buffer_info.items(): + if info.values is not None: + assert np.dtype(info.dtype) == np.uint8 + assert info.btype == BufferType.constant + assert len(info.shape) == 1 + if constant_tensor is None: + buffer_addresses[_buffer] = (0, info.btype) + assert info.values.dtype == np.uint8 + size_in_bytes = info.values.size + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + constant_tensor = np.resize(info.values, size_in_bytes) + else: + buffer_addresses[_buffer] = (constant_tensor.size, info.btype) + assert info.values.dtype == np.uint8 + size_in_bytes = info.values.size + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes)) + else: + size_in_bytes = int( + (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) + ) + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + if info.btype == BufferType.input_or_output: + buffer_type = classify_io(_buffer) + assert buffer_type in (BufferType.input, BufferType.output) + address = 0 + buffer_addresses[_buffer] = (address, buffer_type) + else: + assert info.btype == BufferType.scratch + address = scratch_size + scratch_size += size_in_bytes + buffer_addresses[_buffer] = (address, info.btype) + + for npu_op in npu_ops: + for attr_name, attr in npu_op.__dict__.items(): + if isinstance(attr, list): + new_attr = list() + for attr_ in attr: + 
new_attr.append(replace_tir_loads(attr_)) + setattr(npu_op, attr_name, new_attr) + else: + setattr(npu_op, attr_name, replace_tir_loads(attr)) + + return npu_ops, constant_tensor, scratch_size + + +def translate_ethosu_tir_extern_call(tir_extern_call): + """This is a dispatcher function to dispatch + correct translation call depending on the extern call's + first argument""" + supported_extern_calls = { + "ethosu_conv2d": translate_ethosu_conv2d, + "ethosu_copy": translate_ethosu_copy, + } + ext_call_type = tir_extern_call.args[0].value + assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" + npu_op = supported_extern_calls[ext_call_type](tir_extern_call) + # Some conversions return additional outputs + # if they are needed, the caller should use the function directly + if isinstance(npu_op, tuple): + return npu_op[0] + return npu_op + + +def translate_ethosu_copy(tir_extern_call): + """This function will translate a tir ethosu_copy extern_call + as produced by Relay to TIR compilation. 
+ Parameters + ---------- + tir_extern_call : tvm.tir.Call + + Returns + ------- + ethosu.vela.api.NpuDmaOperation + The vela object containing the params of ethosu_copy + """ + # We skip the first element as it is the extern_call function name + serial_object = spec.create_serial_object(spec.SerialCopy, tir_extern_call.args[1:]) + return _create_npu_dma_op(serial_object) + + def _convert_clip_bounds(npu_op): """ This function will convert the min and max value @@ -330,3 +556,21 @@ def _create_npu_resampling_mode( mode = str(mode.value) assert mode in mode_map.keys() return mode_map[mode] + + +def _create_npu_dma_op(serial_copy): + """This is a helper function to capture the list of arguments + to create a NpuDmaOperation object""" + src = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_copy.read_address, + length=int(serial_copy.length.value), + ) + dest = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_copy.write_address, + length=int(serial_copy.length.value), + ) + return vapi.NpuDmaOperation(src, dest) diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 0919d3fe7a5f7..ee47e4abd42bd 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -23,13 +23,11 @@ from inspect import signature from enum import Enum -from typing import Union, Tuple, Dict, Optional +from typing import Union, Tuple import numpy as np # type: ignore import tvm # type: ignore from tvm import relay -from tvm.relay.build_module import bind_params_by_name # type: ignore -from tvm.relay.backend.contrib.ethosu import preprocess # type: ignore class QConv2DArgs(Enum): @@ -145,41 +143,6 @@ def get_accelerator_config(): return compiler_attrs.accelerator_config -# pylint: disable=unused-argument -def partition_for_ethosu( - mod: tvm.ir.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None, 
**opts -): - """This helper function partition the relay graph as produced by the - relay frontend for a given model into external functions - to be presented to the codegen. - - Parameters - ---------- - mod : tvm.ir.IRModule - The IRModule that gets generated from a relay frontend - params : Optional[Dict[str, tvm.runtime.NDArray]] - Constant input parameters. - - Returns - ------- - mod : IRModule - The partitioned IRModule with external global functions - """ - if params: - mod["main"] = bind_params_by_name(mod["main"], params) - - pattern = relay.op.contrib.get_pattern_table("ethosu") - mod = relay.transform.InferType()(mod) - mod = relay.transform.MergeComposite(pattern)(mod) - mod = relay.transform.AnnotateTarget("ethosu")(mod) - mod = relay.transform.MergeCompilerRegions()(mod) - mod = relay.transform.InferType()(mod) - mod = relay.transform.PartitionGraph()(mod) - mod = relay.transform.InferType()(mod) - mod = preprocess.preprocess_ext_io()(mod) - return mod - - def get_arg_count(func): """Helper function to get the number of arguments in a python function""" @@ -197,3 +160,15 @@ def get_dim_value(layout: str, dim: int): if dim_char == dim: return idx return None + + +def calculate_size_bytes(expr): + """This is a helper function to calculate the number + of bytes required to hold the tensor/relay.expr""" + try: + type_info = np.iinfo(expr.checked_type.dtype) + except ValueError: + type_info = np.finfo(expr.checked_type.dtype) + element_size = type_info.bits // 8 + elements = np.prod(list(expr.checked_type.shape)) + return element_size * elements diff --git a/python/tvm/relay/backend/name_transforms.py b/python/tvm/relay/backend/name_transforms.py new file mode 100644 index 0000000000000..04a7a425bdf13 --- /dev/null +++ b/python/tvm/relay/backend/name_transforms.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Name transformation functions for use in code generation +""" + +from typing import List, Union + +from tvm import TVMError +from . import _backend + + +def to_c_function_style(original_name: str): + """Transform a name to the C function style assuming it is + appropriately constructed using the prefixing functions + + Parameters + ---------- + original_name : str + Original name to transform + """ + return _backend.ToCFunctionStyle(original_name) + + +def to_c_variable_style(original_name: str): + """Transform a name to the C variable style assuming it is + appropriately constructed using the prefixing functions + + Parameters + ---------- + original_name : str + Original name to transform + """ + return _backend.ToCVariableStyle(original_name) + + +def _preprocess_names(names: Union[List[str], str]): + """Preprocesses name strings into format for C++ functions + + Parameters + ---------- + names : Union[List[str], str] + List of names to combine to form a combined name or the name itself + """ + if isinstance(names, str): + if names == "": + raise TVMError("Name is empty") + return [names] + return names + + +def prefix_name(names: Union[List[str], str]): + """Apply TVM-specific prefix to a function name + + Parameters + ---------- + names : Union[List[str], str] + List of names 
to combine to form a combined name or the name itself + """ + + return _backend.PrefixName(_preprocess_names(names)) + + +def prefix_generated_name(names: Union[List[str], str]): + """Apply generated TVM-specific prefix to a function name + + Parameters + ---------- + names : Union[List[str], str] + List of names to combine to form a combined name or the name itself + """ + + return _backend.PrefixGeneratedName(_preprocess_names(names)) + + +def sanitize_name(original_name: str): + """Sanitize name for output into compiler artifacts + + Parameters + ---------- + original_name : str + Original name to sanitize + """ + return _backend.SanitizeName(original_name) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 4d48f5796aca5..ba2c6b4b54e72 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1008,6 +1008,18 @@ def _impl_v1(cls, inputs, attr, params): return _op.power(out, reci_p) +class GlobalLpPool(OnnxOpConverter): + """Operator converter for GlobalLpPool.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + # TODO: GlobalLpPool does not yet support dynamic shapes + in_shape = infer_shape(inputs[0]) + attr["kernel_shape"] = in_shape[2:] + + return LpPool._impl_v1(inputs, attr, params) + + class Mul(Elemwise): """Operator converter for Multiply.""" @@ -1456,11 +1468,56 @@ class Unsqueeze(OnnxOpConverter): """Operator converter for Unsqueeze.""" @classmethod - def _impl_v1(cls, inputs, attr, params): - axes = sorted(attr["axes"]) + def run_calculation(cls, tensor, axes): + axes = sorted(axes) for axis in axes: - inputs[0] = _op.expand_dims(inputs[0], axis=axis, num_newaxis=1) - return inputs[0] + tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1) + return tensor + + @classmethod + def _impl_v1(cls, inputs, attr, params): + return cls.run_calculation(inputs[0], attr["axes"]) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + if isinstance(inputs[1], 
_expr.Constant): + constant_axes = list(inputs[1].data.numpy()) + constant_axes = list(map(int, constant_axes)) + return cls.run_calculation(inputs[0], constant_axes) + + rank_input = len(infer_type(inputs[0]).checked_type.shape) + num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0]) + axes = relay.split(inputs[1], num_new_axis).astuple() + result = inputs[0] + + # TODO (AndrewZhaoLuo): investigate performance issues with consecutive + # dynamic expand_dims on non-llvm targets. + for i in range(num_new_axis): + axis = relay.TupleGetItem(axes, i) + # Unpack scalar + axis = relay.reshape(axis, []) + axis = relay.If( + axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64") + ) + result = _op.expand_dims(result, axis) + return result + + +class Squeeze(OnnxOpConverter): + """Operator converter for Squeeze.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get("axes", None) + return _op.squeeze(*inputs, axis) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = inputs[1] + dtype = infer_type(axis).checked_type.dtype + rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype) + axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis) + return _op.squeeze(inputs[0], fold_constant(axis)) class Split(OnnxOpConverter): @@ -1646,6 +1703,26 @@ def _impl_v12(cls, inputs, attr, params): return cls._impl_common(inputs[0], inputs[1], batch_dims) +class Compress(OnnxOpConverter): + """Operator converter for compress""" + + @classmethod + def _impl_v11(cls, inputs, attr, params): + input_tensor, condition_tensor = inputs + + axis = attr.get("axis", None) + + # Change one hot tensor to indices e.g. 
[0, 1, 1, 0, 1] -> [1, 2, 4] + condition_tensor = _op.reshape(_op.argwhere(condition_tensor), (-1,)) + + if axis is not None: + return _op.take(input_tensor, condition_tensor, axis=axis) + + # if axis is None, flatten input tensor before selection + input_tensor = _op.reshape(input_tensor, (-1,)) + return _op.take(input_tensor, condition_tensor, axis=0) + + class Scatter(OnnxOpConverter): """Operator converter for Scatter.""" @@ -1760,6 +1837,11 @@ class Reduce(OnnxOpConverter): name = "" + @classmethod + def run_calculation(cls, inputs, axis, keepdims): + attr = {"axis": axis, "keepdims": keepdims} + return AttrCvt(cls.name)(inputs, attr) + @classmethod def _impl_v1(cls, inputs, attr, params): if "axes" in attr: @@ -1767,8 +1849,20 @@ def _impl_v1(cls, inputs, attr, params): else: axis_len = len(infer_shape(inputs[0])) axis = list(range(axis_len)) - attr = {"axis": axis, "keepdims": attr.get("keepdims", True)} - return AttrCvt(cls.name)(inputs, attr) + + return cls.run_calculation(inputs, axis, attr.get("keepdims", True)) + + @classmethod + def _impl_v12(cls, inputs, attr, params): + if len(inputs) == 2: + if isinstance(inputs[1], _expr.Constant): + # Get axis and unpack scalar + constant_axis = int(inputs[1].data.numpy()[0]) + return cls.run_calculation([inputs[0]], constant_axis, attr.get("keepdims", True)) + + raise ValueError("Dynamic Reduce is not supported yet!") + + return cls._impl_v1(inputs, attr, params) class ReduceMax(Reduce): @@ -2749,7 +2843,8 @@ def _impl_v12(cls, inputs, attr, params): alpha = _op.const(attr.get("alpha", 1.0), dtype) zero = _op.const(0, dtype) one = _op.const(1, dtype) - return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + out = _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + return out class MaxRoiPool(OnnxOpConverter): @@ -3986,7 +4081,9 @@ def _get_convert_map(opset): "Elu": Elu.get_converter(opset), "Exp": Renamer("exp"), "Greater": Renamer("greater"), + 
"GreaterOrEqual": Renamer("greater_equal"), "Less": Renamer("less"), + "LessOrEqual": Renamer("less_equal"), "Log": Renamer("log"), "Acos": Renamer("acos"), "Acosh": Renamer("acosh"), @@ -4024,6 +4121,7 @@ def _get_convert_map(opset): # defs/nn "AveragePool": AveragePool.get_converter(opset), "LpPool": LpPool.get_converter(opset), + "GlobalLpPool": GlobalLpPool.get_converter(opset), "MaxPool": MaxPool.get_converter(opset), "MaxUnpool": MaxUnpool.get_converter(opset), "Conv": Conv.get_converter(opset), @@ -4071,12 +4169,13 @@ def _get_convert_map(opset): "Gather": Gather.get_converter(opset), "GatherElements": GatherElements.get_converter(opset), "GatherND": GatherND.get_converter(opset), + "Compress": Compress.get_converter(opset), "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": Scatter.get_converter(opset), "ScatterElements": Scatter.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), "EyeLike": EyeLike.get_converter(opset), - "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "Squeeze": Squeeze.get_converter(opset), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), "Shape": Shape.get_converter(opset), diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 39bcfc68e4219..56df39fdaa300 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3713,6 +3713,7 @@ def from_pytorch( custom_convert_map=None, default_dtype="float32", use_parser_friendly_name=False, + keep_quantized_weight=False, ): """Load PyTorch model in the form of a scripted PyTorch model and convert into relay. The companion parameters will be handled automatically. @@ -3745,6 +3746,16 @@ def from_pytorch( so a variable name like "dense.weight" cannot be parsed correctly. Use this option when you want to run the AnnotateSpans pass on the imported module. + keep_quantized_weight : bool + Return quantized weights and bias, rather than float ones. 
PyTorch stores quantized weights + in a custom format, so we cannot directly access 8 bit weights as Numpy arrays. We use + a PyTorch function to unpack quantized weights into float32 arrays and quantization + parameters. By default, we return float32 weights and rely on the QNN lowering and the + Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however, + we cannot apply the constant folding pass on a QNN graph. If keep_quantized_weight is True, + we quantize weights in the frontend using a function that is equivalent to + qnn.op.quantize(...) operating on Numpy arrays. + Returns ------- mod : tvm.IRModule @@ -3789,9 +3800,17 @@ def from_pytorch( # For quantized models quantized_ops = set(["aten::quantize_per_tensor", "quantized::linear_dynamic"]) if len(quantized_ops.intersection(set(op_names))) > 0: - weight_quant_params = qnn_torch.get_weight_quant_params(script_module) - qnn_torch.add_input_quant_params_to_op_inputs(graph) - qnn_torch.add_quant_params_to_outputs(outputs, packed_param_map, weight_quant_params) + weight_quant_params = qnn_torch.get_weight_quant_params( + script_module, packed_param_map.values() + ) + input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph) + qnn_torch.add_quant_params_to_outputs( + outputs, + packed_param_map, + weight_quant_params, + input_scales_for_bias, + keep_quantized_weight, + ) qnn_torch.add_quant_params(tvm_params, weight_quant_params) converter.update_convert_map(qnn_torch.convert_map) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 9eafae905bafa..172ab1e41268e 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -32,16 +32,12 @@ class QNNParam: """A placeholder for weight quantization parameters""" - def __init__(self, weight, bias, scale, zero_point, param_key): - param_prefix = param_key[: -len("._packed_params")] - self.weight_var = _expr.var(param_prefix + 
"_weight", shape=weight.shape) + def __init__(self, weight, bias, scale, zero_point): self.weight = weight if bias is not None: - self.bias_var = _expr.var(param_prefix + "_bias", shape=bias.shape) self.bias = bias.detach().numpy() else: - self.bias_var = None self.bias = None self.scale = _expr.const(scale) @@ -56,13 +52,24 @@ class ConvPackedParam(QNNParam): """ def __init__( - self, weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups + self, + weight_np, + bias, + scale, + zero_point, + stride, + padding, + dilation, + groups, + output_padding, ): - super().__init__(weight_np, bias, scale, zero_point, param_name) + super().__init__(weight_np, bias, scale, zero_point) self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups + # Used only for conv_transpose2d + self.output_padding = output_padding def _get_quant_params(qweight): @@ -81,23 +88,32 @@ def _get_quant_params(qweight): return weight_np, scales, 0 -def make_qnn_param(param_name, qweight, bias): +def make_qnn_param(qweight, bias): weight_np, scale, zero_point = _get_quant_params(qweight) - return QNNParam(weight_np, bias, scale, zero_point, param_name) + return QNNParam(weight_np, bias, scale, zero_point) -def make_conv_packed_param(param_name, qweight, bias, packed_params): +def make_conv_packed_param(qweight, bias, packed_params): weight_np, scale, zero_point = _get_quant_params(qweight) stride = packed_params.stride() padding = packed_params.padding() dilation = packed_params.dilation() groups = packed_params.groups() + output_padding = packed_params.output_padding() return ConvPackedParam( - weight_np, bias, scale, zero_point, param_name, stride, padding, dilation, groups + weight_np, + bias, + scale, + zero_point, + stride, + padding, + dilation, + groups, + output_padding, ) -def get_weight_quant_params(script_module): +def get_weight_quant_params(script_module, packed_param_names): """Retrive and unpack weight parameters from quantized 
modules""" import torch @@ -114,6 +130,9 @@ def filter_func(named_module): key = name + "." + param_name state_dict = m.state_dict() + if key not in packed_param_names: + continue + if len(state_dict) == 0 and not hasattr(m, param_name): # for v1.6 and above # This case seems to happen if a model is serialized @@ -130,31 +149,96 @@ def filter_func(named_module): if "Conv" in m.original_name and len(state_dict) == 0: qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params) - quant_params[key] = make_conv_packed_param(key, qweight, bias, packed_params) + quant_params[key] = make_conv_packed_param(qweight, bias, packed_params) elif "Conv" in m.original_name: qweight, bias = torch.ops.quantized.conv2d_unpack(packed_params) - quant_params[key] = make_qnn_param(key, qweight, bias) + quant_params[key] = make_qnn_param(qweight, bias) elif m.original_name == "LinearPackedParams": qweight, bias = torch.ops.quantized.linear_unpack(packed_params) - quant_params[key] = make_qnn_param(key, qweight, bias) + quant_params[key] = make_qnn_param(qweight, bias) return quant_params -def add_quant_params_to_outputs(outputs, packed_param_map, quant_params): +def quantize_numpy(weight, scale, zero_point, out_dtype_np): + iinfo = np.iinfo(out_dtype_np) + clip_min = iinfo.min + clip_max = iinfo.max + if len(scale.shape) > 0: + scale = np.reshape(scale, [weight.shape[0]] + [1] * (len(weight.shape) - 1)) + transformed = zero_point + weight / scale + return np.clip(np.round(transformed), clip_min, clip_max).astype(out_dtype_np) + + +def add_quant_params_to_outputs( + outputs, packed_param_map, quant_params, input_scales_for_bias, keep_quantized_weight=False +): """ Add quant params to outputs so that they can be referenced by other ops later. Weights are quantized here. 
""" for node_name, packed_param_name in packed_param_map.items(): qparam = quant_params[packed_param_name] - qweight = relay.qnn.op.quantize( - qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 - ) - params = [qweight, qparam.scale, qparam.zero_point, qparam.bias_var] + weight_scale = _get_numpy(qparam.scale) + param_prefix = packed_param_name[: -len("._packed_params")] + + if keep_quantized_weight: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="int8" + ) + qparam.weight = quantize_numpy( + qparam.weight, weight_scale, _get_numpy(qparam.zero_point), np.int8 + ) + qweight = qparam.weight_var + else: + qparam.weight_var = _expr.var( + param_prefix + "_weight", shape=qparam.weight.shape, dtype="float32" + ) + qweight = relay.qnn.op.quantize( + qparam.weight_var, qparam.scale, qparam.zero_point, out_dtype="int8", axis=0 + ) + + if qparam.bias is not None: + float_bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="float32" + ) + if node_name not in input_scales_for_bias: + # This case is for dynamic quantization, where the input activation scale is + # unknown until runtime. 
+ qparam.bias_var = float_bias_var + qbias = qparam.bias_var + elif keep_quantized_weight: + qparam.bias_var = _expr.var( + param_prefix + "_bias", shape=qparam.bias.shape, dtype="int32" + ) + qparam.bias = quantize_numpy( + qparam.bias, input_scales_for_bias[node_name] * weight_scale, 0, np.int32 + ) + qbias = qparam.bias_var + else: + qparam.bias_var = float_bias_var + qbias = relay.qnn.op.quantize( + qparam.bias_var, + _expr.const(input_scales_for_bias[node_name] * weight_scale), + _expr.const(0, "int32"), + out_dtype="int32", + axis=0, + ) + else: + qbias = None + + quant_params[packed_param_name] = qparam + + params = [qweight, qparam.scale, qparam.zero_point, qbias] if isinstance(quant_params[packed_param_name], ConvPackedParam): - params += [qparam.stride, qparam.padding, qparam.dilation, qparam.groups] + params += [ + qparam.stride, + qparam.padding, + qparam.dilation, + qparam.groups, + qparam.output_padding, + ] outputs[node_name] = params @@ -192,6 +276,7 @@ def _get_quant_param_for_input(input_value): "quantized::mul_scalar": (2, 3), "quantized::add_scalar": (2, 3), "quantized::hardswish": (1, 2), + "quantized::conv_transpose2d": qconv_indices, } def dfs(current_node): @@ -362,11 +447,14 @@ def add_input_quant_params_to_op_inputs(graph): "quantized::relu6": 1, "quantized::hardswish": 1, "aten::hardsigmoid": 1, + "quantized::conv_transpose2d": 1, } need_input_quant_param = set(num_quantized_inputs.keys()) need_input_quant_param.add("quantized::cat") + input_scales_for_bias = {} + for node in graph.nodes(): operator = node.kind() if operator not in need_input_quant_param: @@ -401,6 +489,12 @@ def add_input_quant_params_to_op_inputs(graph): node.addInput(scale) node.addInput(zp) + if "conv" in operator or "linear" in operator: + # This is required for quantizing the bias + input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value") + + return input_scales_for_bias + def add_quant_params(params, quant_params): """Add quant parameters to 
TVM param map""" @@ -478,10 +572,7 @@ def _do_bias_and_requantize( # Instead, the torch way requires rounding of activation at runtime if bias is not None: - qbias = relay.qnn.op.quantize( - bias, requant_input_scale, _expr.const(0, "int32"), out_dtype="int32", axis=0 - ) - requantize_input = _op.nn.bias_add(output, qbias) + requantize_input = _op.nn.bias_add(output, bias) else: requantize_input = output @@ -924,6 +1015,65 @@ def _impl(inputs, _): return _impl +def _quantized_conv_transpose2d(with_relu=False): + def _impl(inputs, _): + # Refer to aten/src/ATen/native/quantized/cpu/qconv.cpp + # Supported in Torch 1.7 or newer + conv_params = inputs[1] + weight = conv_params[0] + weight_scale = conv_params[1] + weight_zero_point = conv_params[2] + bias = conv_params[3] + + strides = conv_params[4] + padding = conv_params[5] + dilation = conv_params[6] + groups = conv_params[7] + output_padding = conv_params[8] + + output_scale = _expr.const(inputs[2]) + output_zero_point = _expr.const(inputs[3]) + + assert len(inputs) == 6, "Input quant params not found in op inputs" + + # These are manually added by add_input_quant_params_to_op_inputs above + # In torch, they are retrieved from QTensor data structure at runtime + input_scale = _expr.const(inputs[4]) + input_zero_point = _expr.const(inputs[5]) + + weight_shape = list(infer_shape(weight)) + + # Swap I and O dims to match shape relay expects for OIHW + weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0] + + kernel_size = (weight_shape[2], weight_shape[3]) + out_channels = weight_shape[0] + + conv_out = relay.qnn.op.conv2d_transpose( + inputs[0], + weight, + input_zero_point, + weight_zero_point, + input_scale, + weight_scale, + kernel_size=kernel_size, + dilation=dilation, + strides=strides, + padding=padding, + groups=groups, + channels=out_channels, + output_padding=output_padding, + out_dtype="int32", + kernel_layout="OIHW", + ) + + return _do_bias_and_requantize( + conv_out, bias, input_scale, 
weight_scale, output_scale, output_zero_point, with_relu + ) + + return _impl + + convert_map = { "aten::quantize_per_tensor": _quantize_per_tensor(), "quantized::conv2d_relu": _quantized_conv2d(with_relu=True), @@ -941,4 +1091,5 @@ def _impl(inputs, _): "quantized::relu6": _relu6(), "quantized::linear_dynamic": _linear_dynamic(), "quantized::hardswish": _hswish(), + "quantized::conv_transpose2d": _quantized_conv_transpose2d(), } diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 4d607e46c97f0..93a1dba233f2f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -775,7 +775,7 @@ def convert_softmax(self, op): assert len(output_tensors) == 1, "output tensors length should be 1" output_tensor = output_tensors[0] - params = {"axis": 1} # 1 is channel + params = {"axis": -1} # -1 is channel in_expr = self.get_expr(input_tensor_idx) # TODO - Naive softmax int8 implementation leads to bad accuracy. Currently, we can @@ -2858,7 +2858,7 @@ def convert_transpose_conv(self, op): # Input (data) Tensor. NHWC layout input_tensor = input_tensors[2] - _, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor)) + _, _, _, input_c = to_int_list(self.get_tensor_shape(input_tensor)) # Weights tensor. 
TFLite uses OHWI layout weights_tensor = input_tensors[1] out_channels, kernel_h, kernel_w, in_channels = to_int_list( @@ -2919,8 +2919,9 @@ def convert_transpose_conv(self, op): ), "Output channel in the filter should match to channel in the output_shape" if padding == Padding.SAME: - pad_top, pad_bottom = get_pad_value(input_h, kernel_h, stride_h) - pad_left, pad_right = get_pad_value(input_w, kernel_w, stride_w) + output_h, output_w = output_shape_value[1], output_shape_value[2] + pad_top, pad_bottom = get_pad_value(output_h, kernel_h, stride_h) + pad_left, pad_right = get_pad_value(output_w, kernel_w, stride_w) padding = (pad_top, pad_left, pad_bottom, pad_right) else: padding = (0, 0, 0, 0) diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 809b6369b0854..f5f8870ab0153 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -22,8 +22,16 @@ from .. import op as reg -def on_device(data, device): - """Annotate an expression with a certain device type. +def _device_to_int(device): + if isinstance(device, _Device): + return device.device_type + if isinstance(device, str): + return _nd.device(device).device_type + raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) + + +def on_device(data, device, is_fixed=False): + """Annotates an expression with the device type on which its result should be stored. Parameters ---------- @@ -31,23 +39,45 @@ def on_device(data, device): The expression to be annotated. device : Union[:py:class:`Device`, str] - The device type to annotate. + The device to annotate with. Only the device's type is significant. + + is_fixed : bool + If false (the default), a device_copy + If true, the annotation does not imply a device_copy may be inserted to + reconcile the device of the data argument with the device for the context of the + annotated expression. 
Returns ------- result : tvm.relay.Expr The annotated expression. """ - if isinstance(device, _Device): - device = device.device_type - elif isinstance(device, str): - device = _nd.device(device).device_type - else: - raise ValueError( - "device is expected to be the type of Device or " - "str, but received %s" % (type(device)) - ) - return _make.on_device(data, device) + return _make.on_device(data, _device_to_int(device), is_fixed) + + +def function_on_device(function, param_devices, result_device): + """Annotates a Relay function with the device types on which its parameters and result should + be stored. + + Parameters + ---------- + function : tvm.relay.Function + The function to be annotated. + + param_devices : Array[Union[:py:class:`Device`, str]] + The devices for each parameter. Only the device types are significant. + + result_device: Union[:py:class:`Device`, str] + The device for the function result. Only the device type is significant. + + Returns + ------- + result : tvm.rleay.Function + The annotated function. + """ + return _make.function_on_device( + function, [_device_to_int(d) for d in param_devices], _device_to_int(result_device) + ) def stop_fusion(data): diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 79bd02db164b1..a2fdc19badab9 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -64,7 +64,6 @@ def _func_wrapper(expr): _register_external_op_helper("nn.dense") _register_external_op_helper("nn.relu") _register_external_op_helper("add") -_register_external_op_helper("subtract") _register_external_op_helper("multiply") diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 0da81101c77bd..85ddfd9a7ec84 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -14,19 +14,51 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=ungrouped-imports """Arm(R) Ethos(TM)-U NPU supported operators.""" -from typing import List, Tuple, Callable +import functools + +from typing import Dict, List, Tuple, Callable, Optional import numpy as np # type: ignore import tvm # type: ignore +from tvm import relay from tvm.relay.expr import Constant # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant # type: ignore -from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore -from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs -from tvm.relay.backend.contrib.ethosu.util import RequantArgs -from tvm.relay.backend.contrib.ethosu.util import get_dim_value -from ethosu.vela import api as vapi # type: ignore +from tvm.relay.build_module import bind_params_by_name # type: ignore + +try: + # As ethos-u-vela package is an optional TVM dependency, we want to lazy load it + # and check whether it is installed or not. + # + # In order to show the appropriate error messages when we try to invoke code that + # rely on imports from ethos-u-vela, we protect them with the decorator @requires_vela + # implemented below. 
+ from ethosu.vela import api as vapi # type: ignore + from tvm.relay.backend.contrib.ethosu import preprocess + from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore + from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs + from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import get_dim_value +except ImportError: + vapi = None + + +def requires_vela(func): + """Decorator to check whether we have the required dependency ethos-u-vela + installed as a python package""" + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not vapi: + raise ImportError( + "The 'ethos-u-vela' python package is required for the Arm(R) Ethos(TM)-U NPU " + "backend. Please install the dependency using your Python package manager." + ) from None + return func(*args, **kwargs) + + return wrapper class TensorParams: @@ -36,6 +68,7 @@ class TensorParams: for the creation of tensors in Vela. """ + @requires_vela def __init__(self, tensor, layout=None, scale=None, zero_point=None): self.tensor = tensor if isinstance(tensor, Constant): @@ -148,6 +181,7 @@ class QnnConv2DParams: padding_bounds = [31, 31, 32, 32] activation_map = {"clip": "CLIP"} + @requires_vela def __init__(self, func_body: tvm.relay.Function): activation = None if str(func_body.op) in self.activation_map.keys(): @@ -247,3 +281,39 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal lambda pat: QnnConv2DParams(pat).is_valid(), ) ] + + +# pylint: disable=unused-argument +@requires_vela +def partition_for_ethosu( + mod: tvm.ir.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None, **opts +): + """This helper function partition the relay graph as produced by the + relay frontend for a given model into external functions + to be presented to the codegen. 
+ + Parameters + ---------- + mod : tvm.ir.IRModule + The IRModule that gets generated from a relay frontend + params : Optional[Dict[str, tvm.runtime.NDArray]] + Constant input parameters. + + Returns + ------- + mod : IRModule + The partitioned IRModule with external global functions + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + pattern = relay.op.contrib.get_pattern_table("ethosu") + mod = relay.transform.InferType()(mod) + mod = relay.transform.MergeComposite(pattern)(mod) + mod = relay.transform.AnnotateTarget("ethosu")(mod) + mod = relay.transform.MergeCompilerRegions()(mod) + mod = relay.transform.InferType()(mod) + mod = relay.transform.PartitionGraph()(mod) + mod = relay.transform.InferType()(mod) + mod = preprocess.preprocess_ext_io()(mod) + return mod diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index de8ee0895462e..c909764319d91 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -20,10 +20,13 @@ from tvm.runtime import convert from tvm.te.hybrid import script + from .. 
import op as _reg _reg.register_broadcast_schedule("dyn.broadcast_to") _reg.register_injective_schedule("dyn.reshape") +_reg.register_injective_schedule("dyn.expand_dims") +_reg.register_injective_schedule("dyn.squeeze") _reg.register_broadcast_schedule("dyn.tile") _reg.register_injective_schedule("dyn.one_hot") _reg.register_injective_schedule("dyn.full") @@ -89,6 +92,42 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims): return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] +@script +def _expand_dims_shape_func_input_data(data, axis, ndims, num_newaxis): + out = output_tensor((ndims,), "int64") + + for i in const_range(ndims): + if i < axis: + # We multiply by a check (i < len(data.shape)) to avoid + # a constant folding mechanism leading to an overflow + out[i] = int64(data.shape[i * (i < len(data.shape))]) + elif i - num_newaxis < axis: + out[i] = int64(1) + else: + out[i] = int64( + # We can't use axis in indices as it is not constant but we can + # use negative indices (kind of, have to manually do it) + data.shape[ + (i - num_newaxis) * (i - num_newaxis >= 0) + + (i - num_newaxis + len(data.shape)) * (i - num_newaxis < 0) + ] + ) + + return out + + +@_reg.register_shape_func("dyn.expand_dims", [True, True]) +def dynamic_expand_dims_shape_func(attrs, inputs, out_ndims): + return [ + _expand_dims_shape_func_input_data( + inputs[0], + inputs[1], + out_ndims[0], + convert(attrs.num_newaxis), + ) + ] + + @script def _tile_shape_func(data, reps, ndim, tndim, rndim): out = output_tensor((tndim,), "int64") @@ -220,3 +259,24 @@ def _sparse_to_dense_shape_func(output_shape, ndim): @_reg.register_shape_func("dyn.sparse_to_dense", True) def sparse_to_dense_shape_func(attrs, inputs, out_ndims): return [_sparse_to_dense_shape_func(inputs[3], out_ndims[0])] + + +@script +def _squeeze_shape_func_input_data(data, axis, ndims): + out = output_tensor((ndims,), "int64") + out_i = 0 + for i in const_range(data.shape[0]): + not_in_axis = True + for j in 
const_range(axis.shape[0]): + if i == axis[j]: + not_in_axis = False + if not_in_axis: + out[out_i] = int64(data[i]) + out_i += 1 + + return out + + +@_reg.register_shape_func("dyn.squeeze", [False, True]) +def dynamic_squeeze_shape_func(attrs, inputs, out_ndims): + return [_squeeze_shape_func_input_data(inputs[0], inputs[1], out_ndims[0])] diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 918c36c200795..da7cbd5cec104 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -183,9 +183,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWIO" strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nhwc), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), - name="conv2d_nhwc.cuda", + wrap_compute_conv2d(topi.gpu.conv2d_nhwc), + wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc), + name="conv2d_nhwc.gpu", ) N, H, W, _ = get_const_tuple(data.shape) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 8d9c28ba714b1..1453128eeb677 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -69,9 +69,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWIO" strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_nhwc), - wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), - name="conv2d_nhwc.cuda", + wrap_compute_conv2d(topi.gpu.conv2d_nhwc), + wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc), + name="conv2d_nhwc.gpu", ) N, H, W, _ = get_const_tuple(data.shape) KH, KW, CI, CO = get_const_tuple(kernel.shape) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 2c299022bd6eb..234e76b11813a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -96,7 +96,7 @@ def expand_dims(data, axis, num_newaxis=1): 
data : relay.Expr The input data to the operator. - axis : int + axis : Union[int, Expr] The axis at which the input array is expanded. Should lie in range `[-data.ndim - 1, data.ndim]`. If `axis < 0`, it is the first axis inserted; @@ -110,7 +110,13 @@ def expand_dims(data, axis, num_newaxis=1): result : relay.Expr The reshaped result. """ - return _make.expand_dims(data, axis, num_newaxis) + if isinstance(axis, int): + return _make.expand_dims(data, axis, num_newaxis) + if isinstance(axis, Expr): + # TODO (AndrewZhaoLuo): investigate performance issues with consecutive + # dynamic expand_dims on non-llvm targets. + return _dyn_make.expand_dims(data, axis, num_newaxis) + raise ValueError(f"Unknown type for axis: {type(axis)}") def transpose(data, axes=None): @@ -143,7 +149,7 @@ def squeeze(data, axis=None): data : tvm.relay.Expr The input data to the operator. - axis : None or List[int] + axis : None or List[int] or Expr The set of axes to remove. If axis = None, remove all axis of dimensions 1. If any specified axis has dimension that does not equal 1, it is an error. @@ -153,6 +159,10 @@ def squeeze(data, axis=None): result : tvm.relay.Expr The squeezed result. 
""" + if isinstance(axis, Constant): + axis = list(axis.data.numpy()) + if isinstance(axis, Expr): + return _dyn_make.squeeze(data, axis) return _make.squeeze(data, axis) diff --git a/python/tvm/relay/qnn/op/layout_conversions.py b/python/tvm/relay/qnn/op/layout_conversions.py index a7c90daf36a4e..1a3b1771d6ce3 100644 --- a/python/tvm/relay/qnn/op/layout_conversions.py +++ b/python/tvm/relay/qnn/op/layout_conversions.py @@ -78,3 +78,51 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts): return relay.qnn.op.conv2d(*inputs, **new_attrs) raise ValueError("Layout %s is not yet supported" % desired_data_layout) + + +@reg.register_convert_op_layout("qnn.conv2d_transpose") +def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for QNN conv2d_transpose op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. 
+ + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + + assert ( + len(desired_layouts) == 2 + ), "A desired layout is expected for both of qnn.conv2d_transpose's inputs" + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + + new_attrs = dict(attrs) + new_attrs["data_layout"] = desired_data_layout + + if desired_kernel_layout != "default": + new_attrs["kernel_layout"] = desired_kernel_layout + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + + # Handle default kernel layouts + if desired_data_layout == "NCHW": + new_attrs["kernel_layout"] = "OIHW" + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + if desired_data_layout == "NHWC": + new_attrs["kernel_layout"] = "HWIO" + return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs) + + raise ValueError("Layout %s is not yet supported" % desired_data_layout) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 0ed75191c40df..1adde9a4a4305 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -101,6 +101,17 @@ def avgpool2d(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("nn.global_avg_pool2d") +def global_avgpool2d(expr, type_map): + """Rewrite a global_avgpool op""" + arg = expr.args[0] + t = type_map[arg] + arg = relay.op.cast(arg, "int32") + out = relay.op.nn.global_avg_pool2d(arg) + out = relay.op.cast(out, t.dtype) + return [out, t] + + @register_fake_quantization_to_integer("nn.bias_add") def bias_add(expr, type_map): """Rewrite a bias_add op""" diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 688422284c0f0..bb91afc061953 100644 --- 
a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -546,7 +546,7 @@ def MergeCompilerRegions(): def RewriteAnnotatedOps(fallback_device): """Rewrite the annotated program where annotation operators, e.g. - `on_deivce`, mark which device an expression should be scheduled to. + `on_device`, mark which device an expression should be scheduled to. This pass helps heterogeneous execution where different operators may need to be allocated on various devices. @@ -1167,6 +1167,16 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def PlanDevices(default_device): + """ + Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + every Relay sub-expression should run (and the result stored). Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of + the default_device is ignored. + """ + return _ffi_api.PlanDevices(default_device) + + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 71563b5082902..b3504dbac506e 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -29,5 +29,5 @@ from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib -from .container import String +from .container import String, ShapeTuple from .params import save_param_dict, load_param_dict diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 27811a963993e..2b9f7f9446baf 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -165,9 +165,18 @@ def copyfrom(self, source_array): source_array.shape, shape ) ) - source_array = np.ascontiguousarray( - source_array, dtype="uint16" if dtype == "bfloat16" else dtype + numpy_str_map = DataType.NUMPY2STR + 
np_dtype_str = ( + numpy_str_map[source_array.dtype] + if source_array.dtype in numpy_str_map + else str(source_array.dtype) ) + if (not source_array.flags["C_CONTIGUOUS"]) or ( + dtype == "bfloat16" or dtype != np_dtype_str + ): + source_array = np.ascontiguousarray( + source_array, dtype="uint16" if dtype == "bfloat16" else dtype + ) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index af2f5d8572930..4e5826f5b2a2c 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -286,7 +286,7 @@ def intel_graphics(model="unknown", options=None): "atsamd51": ["-mcpu=cortex-m4"], "cxd5602gg": ["-mcpu=cortex-m4"], "esp32": [], - "imxrt1060": ["-mcpu=cortex-m7"], + "imxrt10xx": ["-mcpu=cortex-m7"], "mps2_an521": ["-mcpu=cortex-m33"], "nrf52840": ["-mcpu=cortex-m4"], "nrf5340dk": ["-mcpu=cortex-m33"], @@ -525,7 +525,7 @@ def hexagon(cpu_ver="v66", **kwargs): # LLVM target string def create_llvm_target(cpu_ver, config): - """ Create LLVM target string. """ + """Create LLVM target string.""" target = " -mtriple=hexagon" mcpu = " -mcpu=hexagon" + cpu_ver @@ -547,7 +547,7 @@ def create_target_features(config): # Simulator options string def create_sim_options(cpu_ver, config): - """ Create simulator option string. """ + """Create simulator option string.""" def validate_hvx_length(codegen_hvx, sim_options): if sim_options and "--hvx_length" in sim_options: @@ -606,7 +606,7 @@ def validate_hvx_length(codegen_hvx, sim_options): # LLVM options string def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create LLVM options string. 
""" + """Create LLVM options string.""" llvm_options = config["llvm_options"] @@ -620,7 +620,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument # TVM target attributes string def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument - """ Create TVM target features string. """ + """Create TVM target features string.""" features = { "link_params": "link-params", diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index a0b9b43735351..4c361bca6c57e 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -445,9 +445,12 @@ def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: import tvm from tvm import te + from tvm.te import create_prim_func + import tvm.script A = te.placeholder((128, 128), name="A") B = te.placeholder((128, 128), name="B") + k = te.reduce_axis((0, 128), "k") C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") func = create_prim_func([A, B, C]) print(tvm.script.asscript(func)) diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index 95875acbd82c7..0413c44208b08 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -48,6 +48,7 @@ "metal": "mark a test as requiring metal", "llvm": "mark a test as requiring llvm", "ethosn": "mark a test as requiring ethosn", + "hexagon": "mark a test as requiring hexagon", } @@ -258,6 +259,8 @@ def _target_to_requirement(target): return utils.requires_opencl() if target.kind.name == "llvm": return utils.requires_llvm() + if target.kind.name == "hexagon": + return utils.requires_hexagon() return [] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 62531ff7c1942..39c759c7cd690 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -442,6 +442,7 @@ def _get_targets(target_str=None): "opencl -device=intel_graphics", "metal", "rocm", + "hexagon", ] @@ -818,6 +819,25 @@ def requires_ethosn(*args): 
return _compose(args, marks) +def requires_hexagon(*args): + """Mark a test as requiring Hexagon to run. + + Parameters + ---------- + f : function + Function to mark + """ + _requires_hexagon = [ + pytest.mark.hexagon, + pytest.mark.skipif(not device_enabled("hexagon"), reason="Hexagon support not enabled"), + *requires_llvm(), + pytest.mark.skipif( + tvm.target.codegen.llvm_version_major() < 7, reason="Hexagon requires LLVM 7 or later" + ), + ] + return _compose(args, _requires_hexagon) + + def requires_package(*packages): """Mark a test as requiring python packages to run. diff --git a/python/tvm/tir/generic.py b/python/tvm/tir/generic.py index 58efc09859708..68e995e01872c 100644 --- a/python/tvm/tir/generic.py +++ b/python/tvm/tir/generic.py @@ -121,7 +121,7 @@ def floordiv(lhs, rhs, span=None): Returns ------- op : tvm.Expr - The result Expr of divide operaton. + The result Expr of floordiv operaton. """ return _ffi_api._OpFloorDiv(lhs, rhs, span) # type: ignore @@ -139,6 +139,6 @@ def cast(src, dtype, span=None): Returns ------- op : tvm.Expr - The result Expr of divide operaton. + The result Expr of cast operaton. """ return _ffi_api._cast(dtype, src, span) # type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7545c09b020d7..d26ffc0b1efaa 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -373,6 +373,7 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: 1) The loops can't have annotations or thread bindings. 2) The (i+1)-th loop must be the only child of the i-th loop. 3) All loops must start with 0. + 4) The domain of a loop to be fused cannot depend on another loop to be fused. 
Parameters ---------- diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index 8338208dd9686..bd8d7ec19bb35 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -25,7 +25,6 @@ from ..nn.utils import get_pad_tuple from ..utils import get_const_tuple, traverse_inline from .conv2d_direct import schedule_direct_cuda -from .conv2d_nhwc import schedule_conv2d_nhwc_direct @autotvm.register_topi_compute("conv2d_nchw.cuda") @@ -48,26 +47,6 @@ def _callback(op): return s -@autotvm.register_topi_compute("conv2d_nhwc.cuda") -def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"): - """Compute conv2d with NHWC layout""" - return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) - - -@autotvm.register_topi_schedule("conv2d_nhwc.cuda") -def schedule_conv2d_nhwc(cfg, outs): - """Create the schedule for conv2d_nhwc""" - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == "conv2d_nhwc": - schedule_conv2d_nhwc_direct(cfg, s, op.output(0)) - - traverse_inline(s, outs[0].op, _callback) - return s - - @autotvm.register_topi_compute("conv2d_cudnn.cuda") def conv2d_cudnn( cfg, data, kernel, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32" diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py index 6d9fd39e16b8e..8ed9362a3cf2a 100644 --- a/python/tvm/topi/gpu/__init__.py +++ b/python/tvm/topi/gpu/__init__.py @@ -18,3 +18,4 @@ # pylint: disable=redefined-builtin, wildcard-import """GPU specific declaration and schedules.""" from .dense import * +from .conv2d import * diff --git a/python/tvm/topi/gpu/conv2d.py b/python/tvm/topi/gpu/conv2d.py new file mode 100644 index 0000000000000..87c900e1d4d76 --- /dev/null +++ b/python/tvm/topi/gpu/conv2d.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Schedule for conv2d operator""" +from tvm import te, autotvm + +from .. import nn +from ..utils import traverse_inline +from .conv2d_nhwc import schedule_conv2d_nhwc_direct + + +@autotvm.register_topi_compute("conv2d_nhwc.gpu") +def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"): + """Compute conv2d with NHWC layout""" + return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_nhwc.gpu") +def schedule_conv2d_nhwc(cfg, outs): + """Create the schedule for conv2d_nhwc""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv2d_nhwc": + schedule_conv2d_nhwc_direct(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/gpu/conv2d_nhwc.py similarity index 90% rename from python/tvm/topi/cuda/conv2d_nhwc.py rename to python/tvm/topi/gpu/conv2d_nhwc.py index e4361e30b5c3b..ff0610394eac8 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc.py +++ b/python/tvm/topi/gpu/conv2d_nhwc.py @@ -54,12 +54,13 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): 
cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2]) cfg.define_knob("vthread_c", [1, 2]) cfg.define_knob("step", [16, 3, 32, 64]) + cfg.define_knob("vectorize", [1, 2, 4, 8]) # fallback support target = tvm.target.Target.current() if cfg.is_fallback: ref_log = autotvm.tophub.load_reference_log( - target.kind.name, target.model, "conv2d_nhwc.cuda" + target.kind.name, target.model, "conv2d_nhwc.gpu" ) cfg.fallback_with_reference_log(ref_log) @@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): vthread_n = cfg["vthread_n"].val vthread_c = cfg["vthread_c"].val step = cfg["step"].val + vec_factor = cfg["vectorize"].val block_factor_c = tile_c * num_thread_c * vthread_c offset = 8 @@ -85,15 +87,17 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy") # Schedule for output - ni, hi, wi, fi = s[output].op.axis - bz = s[output].fuse(hi, wi) + ni, _, wi, fi = s[output].op.axis + bx = wi + fi, vec = s[output].split(fi, factor=vec_factor) + s[output].vectorize(vec) tx, fi = s[output].split(fi, factor=tile_c) txz, tx = s[output].split(tx, factor=num_thread_c) - bx, txz = s[output].split(txz, factor=vthread_c) + bz, txz = s[output].split(txz, factor=vthread_c) ty, ni = s[output].split(ni, factor=tile_n) tyz, ty = s[output].split(ty, factor=num_thread_n) by, tyz = s[output].split(tyz, factor=vthread_n) - s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi) + s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi, vec) s[output].bind(bz, block_z) s[output].bind(by, block_y) s[output].bind(bx, block_x) @@ -106,6 +110,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): ni, yi, xi, fi = s[OL].op.axis ry, rx, rc = s[OL].op.reduce_axis rco, rci = s[OL].split(rc, factor=step) + s[OL].vectorize(fi) s[OL].reorder(rco, ry, rx, rci, ni, fi) s[AA].compute_at(s[OL], rx) @@ -125,6 +130,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): _, _, ic, o = s[WW].op.axis t = s[WW].fuse(ic, o) 
s[WW].storage_align(ic, W_align - 1, W_align) + t, vec = s[WW].split(t, factor=vec_factor) + s[WW].vectorize(vec) ty, tx = s[WW].split(t, factor=num_thread_c) _, ty = s[WW].split(ty, factor=num_thread_n) s[WW].bind(tx, thread_x) diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 876113b85f6e4..b9677d198eba0 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s } auto source = (*it).second; - DLOG(INFO) << "Source: " << std::endl << source->source; + VLOG(1) << "Source: " << std::endl << source->source; - DLOG(INFO) << "ReportAt " - << "span = " << span << " msg = " << diagnostic->message; + VLOG(1) << "ReportAt " + << "span = " << span << " msg = " << diagnostic->message; auto line_text = source.GetLine(span->line); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 4c37f0f1a6e9d..0e42b35349cdb 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -409,8 +409,9 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c const PassInfo& pass_info = Info(); ICHECK(mod.defined()) << "The input module must be set."; - DLOG(INFO) << "Executing module pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level; + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing module pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); mod = pass_func(std::move(mod), pass_ctx); @@ -422,6 +423,8 @@ IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) c pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; + VLOG(1) << "Result module:" << std::endl << PrettyPrint(mod); + return mod; } @@ -473,7 +476,10 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c for (const Pass& pass : passes) { ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& 
pass_info = pass->Info(); - if (!pass_ctx.PassEnabled(pass_info)) continue; + if (!pass_ctx.PassEnabled(pass_info)) { + VLOG(0) << "skipping disabled pass '" << pass_info->name << "'"; + continue; + } // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc new file mode 100644 index 0000000000000..104662b6aad0c --- /dev/null +++ b/src/meta_schedule/arg_info.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace meta_schedule { + +/******** ArgInfo ********/ + +ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { + // The JSON object is always an array whose first element is a tag. For example: + // `['TENSOR', 'float32', [1, 224, 224, 3]] + // Step 1. 
Extract the tag + String tag{runtime::ObjectPtr(nullptr)}; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() >= 1); + tag = Downcast(json_array->at(0)); + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + // Step 2. Dispatch the tag to corresponding subclass of ArgInfo + if (tag == "TENSOR") { + return TensorInfo::FromJSON(json_obj); + } + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj; + throw; +} + +Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { + using support::AsVector; + Array result; + result.reserve(func->params.size()); + for (const tir::Var& arg : func->params) { + if (Optional _buffer = func->buffer_map.Get(arg)) { + tir::Buffer buffer = _buffer.value(); + result.push_back(TensorInfo(/*dtype=*/buffer->dtype, + /*shape=*/AsVector(buffer->shape))); + } else { + LOG(FATAL) << "ValueError: Unsupported argument type: " << arg; + } + } + return result; +} + +/******** TensorInfo ********/ + +TensorInfo::TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape) { + ObjectPtr n = make_object(); + n->dtype = dtype; + n->shape = shape; + this->data_ = std::move(n); +} + +ObjectRef TensorInfoNode::AsJSON() const { + static String tag = "TENSOR"; + String dtype = DLDataType2String(this->dtype); + Array shape = support::AsArray(this->shape); + return Array{tag, dtype, shape}; +} + +TensorInfo TensorInfo::FromJSON(const ObjectRef& json_obj) { + DLDataType dtype; + Array shape; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 3); + // Load json[1] => dtype + { + String dtype_str = Downcast(json_array->at(1)); + dtype = runtime::String2DLDataType(dtype_str); + } + // Load json[2] => shape + shape = Downcast>(json_array->at(2)); + } catch (const std::runtime_error& e) { // includes 
tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TensorInfo(DataType(dtype), ShapeTuple(shape.begin(), shape.end())); +} + +/******** Repr ********/ + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + p->stream << "TensorInfo(\"" << self->dtype << "\", " << self->shape << ")"; + }); + +/******** FFI ********/ + +TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); +TVM_REGISTER_NODE_TYPE(TensorInfoNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoAsJSON").set_body_method(&ArgInfoNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromPrimFunc").set_body_typed(ArgInfo::FromPrimFunc); +TVM_REGISTER_GLOBAL("meta_schedule.ArgInfoFromJSON").set_body_typed(ArgInfo::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TensorInfo") + .set_body_typed([](runtime::DataType dtype, runtime::ShapeTuple shape) -> TensorInfo { + return TensorInfo(dtype, shape); + }); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc new file mode 100644 index 0000000000000..e67b3d1ab9b69 --- /dev/null +++ b/src/meta_schedule/database/database.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/******** Workload ********/ + +Workload::Workload(IRModule mod) { + ObjectPtr n = runtime::make_object(); + n->shash = tvm::StructuralHash()(mod); + n->mod = mod; + data_ = std::move(n); +} + +Workload::Workload(IRModule mod, Workload::THashCode shash) { + ObjectPtr n = runtime::make_object(); + n->mod = mod; + n->shash = shash; + data_ = std::move(n); +} + +ObjectRef WorkloadNode::AsJSON() const { + // Convert `this->mod` to JSON + std::string json_mod = tvm::SaveJSON(this->mod); + // Dump the JSON string to base64 + std::string b64_mod = Base64Encode(json_mod); + // Output + return Array{SHash2Str(this->shash), String(b64_mod)}; +} + +Workload Workload::FromJSON(const ObjectRef& json_obj) { + IRModule mod{nullptr}; + THashCode shash = 0; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 2); + // Load json[0] => shash + String str_shash = Downcast(json_array->at(0)); + // Load json[1] => mod + { + String b64_mod = Downcast(json_array->at(1)); + std::string json_mod = Base64Decode(b64_mod); + mod = Downcast(LoadJSON(json_mod)); + } + // Verify SHash(mod) == shash + shash = tvm::StructuralHash()(mod); + String recalc_shash = SHash2Str(shash); + CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. 
Given: " << str_shash + << "; Recalculated: " << recalc_shash; + } catch (const std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return Workload(mod, shash); +} + +/******** TuningRecord ********/ + +TuningRecord::TuningRecord(tir::Trace trace, Array run_secs, Workload workload, + Target target, Array args_info) { + ObjectPtr n = make_object(); + n->trace = trace; + n->run_secs = run_secs; + n->workload = workload; + n->target = target; + n->args_info = args_info; + this->data_ = n; +} + +ObjectRef TuningRecordNode::AsJSON() const { + Array json_args_info; + json_args_info.reserve(args_info.size()); + for (const ArgInfo& arg_info : args_info) { + json_args_info.push_back(arg_info->AsJSON()); + } + return Array{trace->AsJSON(false), // + run_secs, // + target->Export(), // + json_args_info}; +} + +TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) { + tir::Trace trace{nullptr}; + Array run_secs{nullptr}; + Target target{nullptr}; + Array args_info; + try { + const ArrayNode* json_array = json_obj.as(); + CHECK(json_array && json_array->size() == 4); + // Load json[1] => run_secs + run_secs = Downcast>(json_array->at(1)); + // Load json[2] => target + target = Target(Downcast>(json_array->at(2))); + // Load json[3] => args_info + { + const ArrayNode* json_args_info = json_array->at(3).as(); + args_info.reserve(json_args_info->size()); + for (const ObjectRef& json_arg_info : *json_args_info) { + args_info.push_back(ArgInfo::FromJSON(json_arg_info)); + } + } + // Load json[0] => trace + { + const ObjectRef& json_trace = json_array->at(0); + tir::Schedule sch = + tir::Schedule::Traced(workload->mod, /*seed=*/-1, /*debug_mask=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + tir::Trace::ApplyJSONToSchedule(json_trace, sch); + trace = sch->trace().value(); + } + } catch (const 
std::runtime_error& e) { // includes tvm::Error and dmlc::Error + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + return TuningRecord(trace, run_secs, workload, target, args_info); +} + +/******** PyDatabase ********/ + +Database Database::PyDatabase(PyDatabaseNode::FCommitWorkload f_commit_workload, + PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record, + PyDatabaseNode::FGetTopK f_get_top_k, PyDatabaseNode::FSize f_size) { + ObjectPtr n = make_object(); + n->f_commit_workload = f_commit_workload; + n->f_commit_tuning_record = f_commit_tuning_record; + n->f_get_top_k = f_get_top_k; + n->f_size = f_size; + return Database(n); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(WorkloadNode); +TVM_REGISTER_NODE_TYPE(TuningRecordNode); +TVM_REGISTER_OBJECT_TYPE(DatabaseNode); +TVM_REGISTER_NODE_TYPE(PyDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.Workload").set_body_typed([](IRModule mod) { + return Workload(mod); +}); +TVM_REGISTER_GLOBAL("meta_schedule.WorkloadAsJSON") + .set_body_method(&WorkloadNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.WorkloadFromJSON").set_body_typed(&Workload::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecord") + .set_body_typed([](tir::Trace trace, Array run_secs, Workload workload, Target target, + Array args_info) { + return TuningRecord(trace, run_secs, workload, target, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordAsJSON") + .set_body_method(&TuningRecordNode::AsJSON); +TVM_REGISTER_GLOBAL("meta_schedule.TuningRecordFromJSON").set_body_typed(TuningRecord::FromJSON); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitWorkload") + .set_body_method(&DatabaseNode::CommitWorkload); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseCommitTuningRecord") + .set_body_method(&DatabaseNode::CommitTuningRecord); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseGetTopK") + .set_body_method(&DatabaseNode::GetTopK); 
+TVM_REGISTER_GLOBAL("meta_schedule.DatabaseSize").set_body_method(&DatabaseNode::Size); +TVM_REGISTER_GLOBAL("meta_schedule.DatabasePyDatabase").set_body_typed(Database::PyDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc new file mode 100644 index 0000000000000..3efb72e2fa745 --- /dev/null +++ b/src/meta_schedule/database/json_database.cc @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs); + double b_time = Mean(b->run_secs); + return a_time < b_time; + } +}; + +/*! 
\brief The default database implementation, which mimics two database tables with two files. */ +class JSONDatabaseNode : public DatabaseNode { + public: + /*! \brief The path to the workload table */ + String path_workload; + /*! \brief The path to the tuning record table */ + String path_tuning_record; + /*! \brief All the workloads in the database */ + std::unordered_map workloads2idx_; + /*! \brief All the tuning records in the database */ + std::multiset tuning_records_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("path_workload", &path_workload); + v->Visit("path_tuning_record", &path_tuning_record); + // `workloads2idx_` is not visited + // `tuning_records_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.JSONDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); + + public: + Workload CommitWorkload(const IRModule& mod) { + // Try to insert `mod` into `workloads_` + decltype(this->workloads2idx_)::iterator it; + bool inserted = false; + std::tie(it, inserted) = + this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1); + Workload workload = it->first; + // If `mod` is new in `workloads2idx_`, append it to the workload file + if (inserted) { + it->second = static_cast(this->workloads2idx_.size()) - 1; + JSONFileAppendLine(this->path_workload, JSONObj2Str(workload->AsJSON())); + } + return it->first; + } + + void CommitTuningRecord(const TuningRecord& record) { + this->tuning_records_.insert(record); + JSONFileAppendLine(this->path_tuning_record, + JSONObj2Str(Array{ + /*workload_index=*/Integer(this->workloads2idx_.at(record->workload)), + /*tuning_record=*/record->AsJSON() // + })); + } + + Array GetTopK(const Workload& workload, int top_k) { + CHECK_GE(top_k, 0) << "ValueError: top_k must be non-negative"; + if (top_k == 0) { + return {}; + } + Array results; + results.reserve(top_k); + int counter = 0; + for (const TuningRecord& record : this->tuning_records_) { + if 
(WorkloadEqual()(record->workload, workload)) { + results.push_back(record); + if (++counter == top_k) { + break; + } + } + } + return results; + } + + int64_t Size() { return tuning_records_.size(); } +}; + +Database Database::JSONDatabase(String path_workload, String path_tuning_record, + bool allow_missing) { + ObjectPtr n = make_object(); + // Load `n->workloads2idx_` from `path_workload` + std::vector workloads; + { + Array json_objs = JSONStr2Obj(JSONFileReadLines(path_workload, allow_missing)); + int n_objs = json_objs.size(); + n->workloads2idx_.reserve(n_objs); + workloads.reserve(n_objs); + for (int i = 0; i < n_objs; ++i) { + Workload workload = Workload::FromJSON(json_objs[i]); + n->workloads2idx_.emplace(workload, i); + workloads.push_back(workload); + } + } + // Load `n->tuning_records_` from `path_tuning_record` + { + Array json_objs = JSONStr2Obj(JSONFileReadLines(path_tuning_record, allow_missing)); + for (const ObjectRef& json_obj : json_objs) { + int workload_index = -1; + ObjectRef tuning_record{nullptr}; + try { + const ArrayNode* arr = json_obj.as(); + ICHECK_EQ(arr->size(), 2); + workload_index = Downcast(arr->at(0)); + tuning_record = arr->at(1); + } catch (std::runtime_error& e) { + LOG(FATAL) << "ValueError: Unable to parse the JSON object: " << json_obj + << "\nThe error is: " << e.what(); + } + n->tuning_records_.insert(TuningRecord::FromJSON(tuning_record, workloads[workload_index])); + } + } + n->path_workload = path_workload; + n->path_tuning_record = path_tuning_record; + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase").set_body_typed(Database::JSONDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc new file mode 100644 index 0000000000000..800a76f21e656 --- /dev/null +++ b/src/meta_schedule/runner/runner.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +RunnerInput::RunnerInput(String artifact_path, String device_type, Array args_info) { + ObjectPtr n = make_object(); + n->artifact_path = artifact_path; + n->device_type = device_type; + n->args_info = args_info; + this->data_ = n; +} + +RunnerResult::RunnerResult(Optional> run_secs, Optional error_msg) { + ObjectPtr n = make_object(); + n->run_secs = run_secs; + n->error_msg = error_msg; + this->data_ = n; +} + +RunnerFuture::RunnerFuture(RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) { + ObjectPtr n = make_object(); + n->f_done = f_done; + n->f_result = f_result; + this->data_ = n; +} + +Runner Runner::PyRunner(Runner::FRun f_run) { + ObjectPtr n = make_object(); + n->f_run = f_run; + return Runner(n); +} + +/******** FFI ********/ + +TVM_REGISTER_NODE_TYPE(RunnerInputNode); +TVM_REGISTER_NODE_TYPE(RunnerResultNode); +TVM_REGISTER_NODE_TYPE(RunnerFutureNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_NODE_TYPE(PyRunnerNode); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerInput") + .set_body_typed([](String artifact_path, String device_type, + Array args_info) -> RunnerInput { + return 
RunnerInput(artifact_path, device_type, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerResult") + .set_body_typed([](Array run_secs, Optional error_msg) -> RunnerResult { + return RunnerResult(run_secs, error_msg); + }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFuture") + .set_body_typed([](RunnerFuture::FDone f_done, RunnerFuture::FResult f_result) -> RunnerFuture { + return RunnerFuture(f_done, f_result); + }); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureDone") + .set_body_method(&RunnerFutureNode::Done); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerFutureResult") + .set_body_method(&RunnerFutureNode::Result); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerRun").set_body_method(&RunnerNode::Run); +TVM_REGISTER_GLOBAL("meta_schedule.RunnerPyRunner").set_body_typed(Runner::PyRunner); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc new file mode 100644 index 0000000000000..1c83aee8c0fd4 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using trace and random decisions. */ +class ReplayTraceNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayTraceNode* self; + /*! \brief The design spaces. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayTraceNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The number of threads to use. -1 means using logical cpu number. */ + int num_threads_ = -1; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. 
*/ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this, design_spaces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayTraceNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + ICHECK_LT(st, ed); + std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); + Array per_task_result(ed - st, MeasureCandidate{nullptr}); + auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, + int task_id) -> void { + TRandState& rand_state = per_thread_rand_state[thread_id]; + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace 
trace = design_spaces[design_space_index]->trace().value(); + tir::Trace new_trace = tir::Trace(trace->insts, {}); + tir::Schedule sch = tir::Schedule::Traced( // + self->mod_, // + /*rand_state=*/ForkSeed(&rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + }; + support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); + return per_task_result; +} + +inline void ReplayTraceNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayTraceNode); +TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc new file mode 100644 index 0000000000000..fefe8dfce76e9 --- /dev/null +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCandidate::MeasureCandidate(tir::Schedule sch, Array args_info) { + ObjectPtr n = make_object(); + n->sch = sch; + n->args_info = args_info; + data_ = std::move(n); +} + +SearchStrategy SearchStrategy::PySearchStrategy( + PySearchStrategyNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PySearchStrategyNode::FPreTuning f_pre_tuning, // + PySearchStrategyNode::FPostTuning f_post_tuning, // + PySearchStrategyNode::FGenerateMeasureCandidates f_generate_measure_candidates, // + PySearchStrategyNode::FNotifyRunnerResults f_notify_runner_results) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = f_initialize_with_tune_context; + n->f_pre_tuning = f_pre_tuning; + n->f_post_tuning = f_post_tuning; + n->f_generate_measure_candidates = f_generate_measure_candidates; + n->f_notify_runner_results = f_notify_runner_results; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); +TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); +TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCandidate") + .set_body_typed([](tir::Schedule sch, Array args_info) -> MeasureCandidate { + return MeasureCandidate(sch, args_info); + }); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPySearchStrategy") + .set_body_typed(SearchStrategy::PySearchStrategy); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyInitializeWithTuneContext") + .set_body_method(&SearchStrategyNode::InitializeWithTuneContext); 
+TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPreTuning") + .set_body_method(&SearchStrategyNode::PreTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyPostTuning") + .set_body_method(&SearchStrategyNode::PostTuning); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyGenerateMeasureCandidates") + .set_body_method(&SearchStrategyNode::GenerateMeasureCandidates); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyNotifyRunnerResults") + .set_body_method(&SearchStrategyNode::NotifyRunnerResults); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc new file mode 100644 index 0000000000000..6df8da2f7aa12 --- /dev/null +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +SpaceGenerator SpaceGenerator::PySpaceGenerator( + PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, + PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_generate_design_space = std::move(f_generate_design_space); + return SpaceGenerator(n); +} + +TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); +TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") + .set_body_method(&SpaceGeneratorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") + .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") + .set_body_typed(SpaceGenerator::PySpaceGenerator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc new file mode 100644 index 0000000000000..9c2e3eeabe099 --- /dev/null +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The union of design space generators. */ +class SpaceGeneratorUnionNode : public SpaceGeneratorNode { + public: + /*! \brief The array of design space generators unioned, could be recursive. */ + Array space_generators; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("space_generators", &space_generators); } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + // Initialize each space generator. + for (const SpaceGenerator& space_generator : space_generators) { + space_generator->InitializeWithTuneContext(tune_context); + } + } + + Array GenerateDesignSpace(const IRModule& mod) final { + Array design_spaces; + for (const SpaceGenerator& space_generator : space_generators) { + // Generate partial design spaces from each design space generator. + Array partial = space_generator->GenerateDesignSpace(mod); + // Merge the partial design spaces. + design_spaces.insert(design_spaces.end(), partial.begin(), partial.end()); + } + return design_spaces; + } + + static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; + TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode); +}; + +/*! + * \brief Create a design space generator as union of given design space generators. + * \param space_generators Array of the design space generators to be unioned. + * \return The design space generator created. 
+ */ +SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators) { + ObjectPtr n = make_object(); + n->space_generators = std::move(space_generators); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") + .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc new file mode 100644 index 0000000000000..ad82b6f514a22 --- /dev/null +++ b/src/meta_schedule/tune_context.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "./utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Constructor function of TuneContext class. + * \param mod The mod to be optimized. + * \param target The target to be optimized for. + * \param space_generator The design space generator. + * \param task_name The name of the tuning task. + * \param rand_state The random state. + * \param num_threads The number of threads to be used. + * \param verbose The verbosity level. 
+ */ +TuneContext::TuneContext(Optional mod, // + Optional target, // + Optional space_generator, // + Optional task_name, // + support::LinearCongruentialEngine::TRandState rand_state, // + int num_threads) { + ObjectPtr n = make_object(); + n->mod = mod; + n->target = target; + n->space_generator = space_generator; + n->task_name = task_name; + if (rand_state == -1) { + rand_state = std::random_device()(); + } + support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); + n->num_threads = num_threads; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TuneContextNode); + +TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") + .set_body_typed([](Optional mod, // + Optional target, // + Optional space_generator, // + Optional task_name, // + support::LinearCongruentialEngine::TRandState rand_state, // + int num_threads) -> TuneContext { + return TuneContext(mod, target, space_generator, task_name, rand_state, num_threads); + }); +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 47331203a25af..30294b8f91e13 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -19,10 +19,195 @@ #ifndef TVM_META_SCHEDULE_UTILS_H_ #define TVM_META_SCHEDULE_UTILS_H_ +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../printer/text_printer.h" +#include "../support/array.h" +#include "../support/base64.h" +#include "../tir/schedule/primitive.h" namespace tvm { -namespace meta_schedule {} // namespace meta_schedule +namespace meta_schedule { + +/*! + * \brief Read lines from a json file. + * \param path The path to the json file. + * \param allow_missing Whether to create new file when the given path is not found. + * \return An array containing lines read from the json file. 
+ */ +inline Array JSONFileReadLines(const String& path, bool allow_missing) { + std::ifstream is(path); + if (is.good()) { + Array results; + for (std::string str; std::getline(is, str);) { + results.push_back(str); + } + return results; + } + CHECK(allow_missing) << "ValueError: File doesn't exist: " << path; + std::ofstream os(path); + CHECK(os.good()) << "ValueError: Cannot create new file: " << path; + return {}; +} + +/*! + * \brief Append a line to a json file. + * \param path The path to the json file. + * \param line The line to append. + */ +inline void JSONFileAppendLine(const String& path, const std::string& line) { + std::ofstream os(path, std::ofstream::app); + CHECK(os.good()) << "ValueError: Cannot open the file to write: " << path; + os << line << std::endl; +} + +/*! + * \brief Get the base64 encoded result of a string. + * \param str The string to encode. + * \return The base64 encoded string. + */ +inline std::string Base64Encode(std::string str) { + std::string result; + dmlc::MemoryStringStream m_stream(&result); + support::Base64OutStream b64stream(&m_stream); + static_cast(&b64stream)->Write(str); + b64stream.Finish(); + return result; +} + +/*! + * \brief Get the base64 decoded result of a string. + * \param str The string to decode. + * \return The base64 decoded string. + */ +inline std::string Base64Decode(std::string str) { + std::string result; + dmlc::MemoryStringStream m_stream(&str); + support::Base64InStream b64stream(&m_stream); + b64stream.InitPosition(); + static_cast(&b64stream)->Read(&result); + return result; +} + +/*! + * \brief Parse lines of json string into a json object. + * \param lines The lines of json string. + * \return Array of json objects parsed. + * \note The function calls the python-side json parser in runtime registry. 
+ */ +inline Array JSONStr2Obj(const Array& lines) { + static const runtime::PackedFunc* f_to_obj = + runtime::Registry::Get("meta_schedule.batch_json_str2obj"); + ICHECK(f_to_obj) << "IndexError: Cannot find the packed function " + "`meta_schedule.batch_json_str2obj` in the global registry"; + return (*f_to_obj)(lines); +} + +/*! + * \brief Serialize a json object into a json string. + * \param json_obj The json object to serialize. + * \return A string containing the serialized json object. + * \note The function calls the python-side json obj serializer in runtime registry. + */ +inline String JSONObj2Str(const ObjectRef& json_obj) { + static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("meta_schedule.json_obj2str"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`meta_schedule.json_obj2str` in the global registry"; + return (*f_to_str)(json_obj); +} + +/*! + * \brief Converts a structural hash code to string + * \param hash_code The hash code + * \return The string representation of the hash code + */ +inline String SHash2Str(Workload::THashCode hash_code) { return std::to_string(hash_code); } + +/*! + * \brief Find the entry function of the given IRModule, i.e., functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main`, or which is the only PrimFunc. + * \param mod The IRModule to find the entry function. + * \return The entry function. 
+ */ +inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + return GetRef(func); + } + if (gv->name_hint == "main") { + main_func = func; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + return GetRef(main_func); + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 0) { + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " + << tir::AsTVMScript(mod); + } + if (num_prim_func > 1) { + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << tir::AsTVMScript(mod); + } + return GetRef(last_func); +} + +/*! + * \brief Fork a random state into another, i.e. PRNG splitting. + * The given random state is also mutated. + * \param rand_state The random state to be forked + * \return The forked random state + */ +inline support::LinearCongruentialEngine::TRandState ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state) { + return support::LinearCongruentialEngine(rand_state).ForkSeed(); +} + +/*! + * \brief Fork a random state into another ones, i.e. PRNG splitting. + * The given random state is also mutated. 
+ * \param rand_state The random state to be forked + * \param n The number of forks + * \return The forked random states + */ +inline std::vector ForkSeed( + support::LinearCongruentialEngine::TRandState* rand_state, int n) { + std::vector results; + results.reserve(n); + for (int i = 0; i < n; ++i) { + results.push_back(support::LinearCongruentialEngine(rand_state).ForkSeed()); + } + return results; +} + +} // namespace meta_schedule } // namespace tvm #endif // TVM_META_SCHEDULE_UTILS_H_ diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 1fa72c92b6fc1..8e52af60d2351 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,6 +19,7 @@ /*! * \file src/node/structural_equal.cc */ +#include #include #include #include @@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { // Check the result. bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_ && !result) { - LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n" - << "lhs = " << lhs << "\nrhs = " << rhs; + LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl + << PrettyPrint(lhs) << std::endl + << "and rhs:" << std::endl + << PrettyPrint(rhs); } return result; } diff --git a/src/parser/meta_ref.h b/src/parser/meta_ref.h index 481f334cb0fe0..483b7f726e073 100644 --- a/src/parser/meta_ref.h +++ b/src/parser/meta_ref.h @@ -26,6 +26,7 @@ #define TVM_PARSER_META_REF_H_ #include +#include #include #include @@ -36,8 +37,6 @@ namespace parser { using namespace relay; -using MetaTable = Map>; - /*! * \brief Options for allocating storage. */ diff --git a/src/parser/parser.cc b/src/parser/parser.cc index c6407e8909d93..5eec716cc20c9 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -417,7 +417,7 @@ class Parser { * Useful for matching optional tokens, effectively looksahead by one. 
*/ bool WhenMatch(const TokenType& token_type) { - DLOG(INFO) << "Parser::WhenMatch: Peek() == " << Peek(); + VLOG(1) << "Parser::WhenMatch: Peek() == " << Peek(); if (Peek()->token_type == token_type) { Consume(token_type); return true; @@ -594,7 +594,7 @@ class Parser { template R WithSpan(std::function parser) { auto start_span = Peek()->span; - DLOG(INFO) << "WithSpan: start_span = " << start_span; + VLOG(0) << "WithSpan: start_span = " << start_span; R ast = parser(); if (ast.defined()) { // The token at the head of the stream is now 1 past where we parsed. So we find its start @@ -608,7 +608,7 @@ class Parser { span_pos--; } auto end_token = tokens.at(span_pos); - DLOG(INFO) << "WithSpan: end_span = " << end_token->span; + VLOG(0) << "WithSpan: end_span = " << end_token->span; ast->span = start_span.Merge(end_token->span); } return ast; @@ -668,8 +668,8 @@ class Parser { template Array ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function parse, std::function before_stop = nullptr) { - DLOG(INFO) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) - << " stop=" << ToString(stop); + VLOG(0) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep) + << " stop=" << ToString(stop); Match(start); // This is for the empty arguments list case, if we have token stream @@ -686,7 +686,7 @@ class Parser { if (WhenMatch(stop)) { return Array(); } else { - DLOG(INFO) << "Parser::ParseSequence: parse first"; + VLOG(0) << "Parser::ParseSequence: parse first"; auto data = parse(); Array elements = {data}; @@ -695,7 +695,7 @@ class Parser { // parse '( expr ',' * ')' } else if (WhenMatch(sep)) { while (true) { - DLOG(INFO) << "Parser::ParseSequence: parse element"; + VLOG(0) << "Parser::ParseSequence: parse element"; if (WhenMatch(stop)) { break; } else { @@ -893,12 +893,12 @@ class Parser { /*! \brief Parse a single Relay expression. 
*/ Expr ParseExpr() { - DLOG(INFO) << "Parser::ParseExpr"; + VLOG(0) << "Parser::ParseExpr"; return WithSpan([this] { std::vector exprs; while (true) { - DLOG(INFO) << "Parser::ParseExpr: parsing a single expression"; + VLOG(0) << "Parser::ParseExpr: parsing a single expression"; auto next = Peek(); switch (next->token_type) { // For graph or let, match first rhs, then invoke ParseBindingExpr @@ -1011,7 +1011,7 @@ class Parser { // This ensures for n sequential bindings // the call depth will be the same before // and after parsing the n bindings. - DLOG(INFO) << "Parser::ParseBindingExpr"; + VLOG(0) << "Parser::ParseBindingExpr"; std::vector> bindings; int scopes = 0; @@ -1085,15 +1085,13 @@ class Parser { * Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }. */ Function ParseFunctionDef() { - DLOG(INFO) << "Parser::ParseFunctionDef"; + VLOG(0) << "Parser::ParseFunctionDef"; return WithSpan([&]() { PushScope(); PushTypeScope(); Array generics; if (Peek()->token_type == TokenType::kLSquare) { - // If we have generics we need to add a type scope. - PushTypeScope(); generics = ParseSequence( TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() { auto type_var_name = Match(TokenType::kIdentifier).ToString(); @@ -1149,7 +1147,7 @@ class Parser { /*! \brief Parse an if-expression. */ Expr ParseIf() { return WithSpan([&]() { - DLOG(INFO) << "Parser::ParseIf"; + VLOG(0) << "Parser::ParseIf"; Consume(TokenType::kIf); auto guard = WithSpan([&] { return Parens([&] { return ParseExpr(); }); }); @@ -1188,7 +1186,7 @@ class Parser { * This function recursively parses a pattern. 
*/ Pattern ParsePattern() { - DLOG(INFO) << "Parser::ParsePattern"; + VLOG(0) << "Parser::ParsePattern"; auto next = Peek(); switch (next->token_type) { case TokenType::kUnderscore: { @@ -1251,7 +1249,7 @@ class Parser { } Expr ParseExprBinOp() { - DLOG(INFO) << "Parser::ParseExprBinOp"; + VLOG(0) << "Parser::ParseExprBinOp"; return WithSpan([this] { // We must parse at least one expression, the default // case is that there is no operator and we will fall @@ -1335,7 +1333,7 @@ class Parser { } ObjectRef ParseAttributeValue() { - DLOG(INFO) << "Parser::ParseAttributeValue"; + VLOG(0) << "Parser::ParseAttributeValue"; auto next = Peek(); switch (next->token_type) { case TokenType::kFloat: @@ -1377,7 +1375,7 @@ class Parser { } Map ParseAttrs() { - DLOG(INFO) << "Parser::ParseAttrs"; + VLOG(0) << "Parser::ParseAttrs"; Map kwargs; while (Peek()->token_type == TokenType::kIdentifier) { auto key = GetHierarchicalName(ParseHierarchicalName().data); @@ -1389,14 +1387,14 @@ class Parser { kwargs.Set(key, value); WhenMatch(TokenType::kComma); } - DLOG(INFO) << "Parser::ParseAttrs: kwargs=" << kwargs; + VLOG(0) << "Parser::ParseAttrs: kwargs=" << kwargs; return kwargs; } Expr ParseCallArgs(Expr op) { ICHECK(op.defined()) << "the operator must be defined"; - DLOG(INFO) << "Parser::ParseCallArgs"; + VLOG(0) << "Parser::ParseCallArgs"; Attrs attrs; std::string op_key; bool is_op = false; @@ -1444,6 +1442,10 @@ class Parser { ICHECK(attr_obj.defined()); attrs = Downcast(attr_obj); } + } else { + this->diag_ctx.EmitFatal(Diagnostic::Error(op->span) + << "unable to determine the 'attrs_type_key' with which " + "to represent the call attributes for this operator"); } } return true; @@ -1469,7 +1471,7 @@ class Parser { } Expr ParseCallExpr() { - DLOG(INFO) << "Parser::ParseCallExpr"; + VLOG(0) << "Parser::ParseCallExpr"; return WithSpan([this] { Expr expr = ParseAtomicExpr(); // Parse as many call args as possible, building up expression @@ -1498,7 +1500,7 @@ class Parser { } Expr 
GetOp(const std::string& op_name, const Span& span) { - DLOG(INFO) << "op_name=" << op_name << " span=" << span; + VLOG(0) << "op_name=" << op_name << " span=" << span; try { return Op::Get(op_name); } catch (const Error& e) { @@ -1511,7 +1513,7 @@ class Parser { } Expr ParseAtomicExpr() { - DLOG(INFO) << "Parser::ParseAtomicExpr"; + VLOG(0) << "Parser::ParseAtomicExpr"; Expr expr = WithSpan([this] { auto next = Peek(); switch (next->token_type) { @@ -1647,7 +1649,7 @@ class Parser { auto token = Match(TokenType::kInteger); auto index = token.ToNumber(); auto span = token->span.Merge(expr->span); - DLOG(INFO) << "Parser::ParseAtomicExpr: tuple get item"; + VLOG(0) << "Parser::ParseAtomicExpr: tuple get item"; return relay::TupleGetItem(expr, index, span); } else { return expr; @@ -1867,9 +1869,8 @@ class Parser { }; Parser InitParser(const std::string& file_name, const std::string& file_content, - Optional init_module) { - DLOG(INFO) << "InitParser: file_name: " << file_name - << "file_content_size: " << file_content.size(); + const Optional& init_module, const MetaTable& init_meta_table) { + VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size(); SourceName src_name = SourceName::Get(file_name); Source source(src_name, file_content); @@ -1887,19 +1888,33 @@ Parser InitParser(const std::string& file_name, const std::string& file_content, auto tokens_and_table = Tokenize(diag_ctx, source); auto tokens = tokens_and_table.first; - auto meta_data_table = tokens_and_table.second; + MetaTable meta_data_table = tokens_and_table.second.ToMetadata(); + + // Merge any entries in init_meta_table into anything captured in the #[metadata] section + // of the file_content. Metadata references within file_content must use indexes which account + // for this ordering. 
+ for (const auto& pair : init_meta_table) { + Array items; + if (meta_data_table.count(pair.first)) { + items = meta_data_table[pair.first]; + } + for (const auto& obj : pair.second) { + items.push_back(obj); + } + meta_data_table.Set(pair.first, items); + } - return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata()); + return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table)); } -IRModule ParseModule(std::string file_name, std::string file_content, - Optional init_module) { - DLOG(INFO) << "ParseModule"; - auto parser = InitParser(file_name, file_content, init_module); +IRModule ParseModule(const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + VLOG(0) << "ParseModule"; + auto parser = InitParser(file_name, file_content, init_module, init_meta_table); auto mod = parser.ParseModule(); ICHECK(mod.defined()) << "The parser must return a non-null module."; - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. 
parser.diag_ctx.Render(); auto infer_type = tvm::relay::transform::InferType(); @@ -1907,22 +1922,28 @@ IRModule ParseModule(std::string file_name, std::string file_content, return infer_type(mod); } -Expr ParseExpr(std::string file_name, std::string file_content) { - DLOG(INFO) << "ParseExpr"; - auto parser = InitParser(file_name, file_content, Optional()); +Expr ParseExpr(const std::string& file_name, const std::string& file_content) { + VLOG(0) << "ParseExpr"; + auto parser = InitParser(file_name, file_content, Optional(), MetaTable()); parser.ParseSemVer(false); parser.PushScope(); auto expr = parser.ParseExpr(); parser.Match(TokenType::kEndOfFile); - // NB(@jroesch): it is very important that we render any errors before we procede - // if there were any errors which allow the parser to procede we must render them + // NB(@jroesch): it is very important that we render any errors before we proceed + // if there were any errors which allow the parser to proceed we must render them // here. 
parser.diag_ctx.Render(); return expr; } +TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") + .set_body_typed([](const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + return ParseModule(file_name, file_content, init_module, init_meta_table); + }); + TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](tvm::String file_name, tvm::String file_content) { + .set_body_typed([](const std::string& file_name, const std::string& file_content) { return ParseModule(file_name, file_content); }); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 4e79d0e74c592..3c1329670c40e 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) { } tvm::String Source::GetLine(int line) { - DLOG(INFO) << "Source::GetLine: line=" << line; + VLOG(1) << "Source::GetLine: line=" << line; ICHECK(line - 1 < static_cast((*this)->line_map.size())) << "requested line: " << line << "at index: " << (line - 1) << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) { auto range = (*this)->line_map.at(line - 1); int line_start = range.first; int line_length = range.second; - DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; // TODO(@jroesch): expose substring on tvm::String. 
auto line_text = std::string((*this)->source).substr(line_start, line_length); - DLOG(INFO) << "Source::GetLine: line_text=" << line_text; + VLOG(1) << "Source::GetLine: line_text=" << line_text; return line_text; } diff --git a/src/parser/tokenizer.h b/src/parser/tokenizer.h index d7b3b3f6f681f..8f197db45318d 100644 --- a/src/parser/tokenizer.h +++ b/src/parser/tokenizer.h @@ -339,7 +339,7 @@ struct Tokenizer { int line = this->line; int col = this->col; auto next = Peek(); - DLOG(INFO) << "tvm::parser::TokenizeOnce: next=" << next; + VLOG(1) << "tvm::parser::TokenizeOnce: next=" << next; if (next == '\n') { auto token = NewToken(TokenType::kNewline); Next(); @@ -550,7 +550,7 @@ struct Tokenizer { } void Tokenize() { - DLOG(INFO) << "tvm::parser::Tokenize"; + VLOG(0) << "tvm::parser::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); ICHECK(token.defined()); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index aad42fc9b0ea5..ea97bb35a09f2 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -34,6 +34,7 @@ */ #include #include +#include #include #include #include @@ -119,6 +120,9 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); + } else if (!show_meta_data_ && node.as()) { + // Show attributes in readable form. + return PrintAttrs(Downcast(node)); } else { // default module. 
std::ostringstream os; @@ -769,6 +773,8 @@ Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { printed_attr << "?"; } else if (auto str_obj = value.as()) { printed_attr << Doc::StrLiteral(GetRef(str_obj)); + } else if (const auto* on_device_attrs = value.as()) { + printed_attr << "device_type=" << on_device_attrs->device_type; } else if (meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); } else { @@ -848,17 +854,28 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { RelayTextPrinter* parent_; }; +Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) { + std::vector docs; + AttrPrinter printer(&docs, this); + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + + return doc; +} + std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); - if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { + if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) { // fallback Doc doc; doc << meta_->GetMetaNode(attrs); docs.push_back(doc); return docs; } else { + // Show attributes in readable form. 
AttrPrinter printer(&docs, this); const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); if (!op_node) { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 1e882db1fd61c..b8533a5d88011 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -67,7 +67,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { String PrettyPrint(const ObjectRef& node) { Doc doc; - doc << TextPrinter(false, nullptr, false).PrintFinal(node); + doc << TextPrinter(/*show_meta_data=*/false, nullptr, false).PrintFinal(node); return doc.str(); } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 0332a2d539d29..7e4a56529ddc7 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -77,6 +77,7 @@ class RelayTextPrinter : public ExprFunctor, // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const ObjectRef& node); Doc PrintFinal(const ObjectRef& node); + Doc PrintAttrs(const Attrs& attrs); std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); std::vector PrintFuncAttrs(const Attrs& attrs); Doc PrintSpan(const Span& span); diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 8fd0f87239ff5..f150453ba0b66 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -40,6 +40,8 @@ class OperatorExtractorWrapper : private MixedModeVisitor { } private: + using MixedModeVisitor::VisitExpr_; + const IRModule mod_; /*! \brief Map of operator to frequency. 
*/ Map operator_freqs_; diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index ad9ba1b2069da..fc850e37379c6 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -291,7 +291,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { args.push_back(param_handle); } else { auto var_arg = FindExpr(arg); - args.push_back(var_arg[0]); + for (const auto& var : var_arg) { + args.push_back(var); + } } } @@ -623,8 +625,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { + // The buffer_var is created with storage_scope to be global.workspace to be serviced by + // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor + // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and + // should not be lowered to the stack. For more details please refer to the discussion here: + // https://github.com/apache/tvm/issues/9022 te::Var buffer_var(MakeString("sid_", sid), - PointerType(PrimType(DataType::Int(8)), "global")); + PointerType(PrimType(DataType::Int(8)), "global.workspace")); sids_table_[sid] = buffer_var; } } @@ -658,7 +665,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Apply storage rewrite pass to the runner function to do memory planning auto storage_rewrite = tir::transform::StorageRewrite(); mod_run = storage_rewrite(mod_run); - // The workspace for main function should be calculated after performing storage_rewrite for // the top level TIR function. auto workspace_byte_alignment = diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6142e8323dea1..0e7af22783750 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -19,7 +19,7 @@ /*! 
* \file relay/backend/compile_engine.cc - * \brief Internal compialtion engine. + * \brief Internal compilation engine. */ #include "compile_engine.h" diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index e96255e976e93..ae58c2f08e8cf 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -54,6 +54,15 @@ inline size_t GetShape1DSize(const Type& type) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } +inline std::string GetShapeString(std::vector shape) { + std::string v = "std::vector{"; + for (auto s : shape) { + v += std::to_string(s) + ","; + } + v += "}"; + return v; +} + std::vector Conv2d(const CallNode* call) { std::vector args; const auto* conv2d_attr = call->attrs.as(); @@ -67,11 +76,13 @@ std::vector Conv2d(const CallNode* call) { args.push_back(std::to_string(s)); } - // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw + // Args: O, G, Ph0, Pw0, Ph1, Pw1, Kh, Kw, Sh, Sw args.push_back(std::to_string(wshape[0])); args.push_back(std::to_string(conv2d_attr->groups)); args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); + args.push_back(std::to_string(conv2d_attr->padding[2].as()->value)); + args.push_back(std::to_string(conv2d_attr->padding[3].as()->value)); args.push_back(std::to_string(wshape[2])); args.push_back(std::to_string(wshape[3])); args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); @@ -96,12 +107,8 @@ std::vector Dense(const CallNode* call) { std::vector Relu(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - + args.push_back(GetShapeString(ishape)); return args; } @@ -121,15 +128,25 @@ std::vector BatchNorm(const CallNode* call) { return args; } +// should comply with 
src/runtime/contrib/dnnl/dnnl.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + std::vector Add(const CallNode* call) { std::vector args; auto ishape = GetShape(call->args[0]->checked_type()); - + args.push_back(std::to_string(DNNL_BINARY_ADD)); // Args: H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } + args.push_back(GetShapeString(ishape)); + return args; +} +std::vector Multiply(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + args.push_back(std::to_string(DNNL_BINARY_MUL)); + // Args: H, W + args.push_back(GetShapeString(ishape)); return args; } @@ -237,11 +254,9 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C using ArgFunType = std::function(const CallNode*)>; static const std::map> op_map = { - {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, - {"nn.dense", {"dnnl_dense", Dense}}, - {"nn.relu", {"dnnl_relu", Relu}}, - {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, - {"add", {"dnnl_add", Add}}, + {"nn.conv2d", {"dnnl_conv2d", Conv2d}}, {"nn.dense", {"dnnl_dense", Dense}}, + {"nn.relu", {"dnnl_relu", Relu}}, {"nn.batch_norm", {"dnnl_bn", BatchNorm}}, + {"add", {"dnnl_binary_op", Add}}, {"multiply", {"dnnl_binary_op", Multiply}}, }; const auto op_name = GetRef(op_node)->name; diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc new file mode 100644 index 0000000000000..61a880e17ffba --- /dev/null +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../../../runtime/file_utils.h" + +namespace tvm { +namespace runtime { + +class EthosUModuleNode : public ModuleNode { + public: + /*! + * \brief The ethos runtime module. + * + * \param cmms A array of external symbol 1, serialized command stream 1 + * external symbol 2, serialized command stream 2, .... + * TODO : if and when FFI support Maps with non-objects OR compound arrays + * switch to that. + */ + explicit EthosUModuleNode(const String& func_name_, const String& cmms_hex_, + const String& weights_bias_hex_, const Integer& scratch_size_, + const Integer& input_size_, const Integer& output_size_) { + func_names_.push_back(func_name_); + cmms_hex = std::move(cmms_hex_); + weights_bias_hex = std::move(weights_bias_hex_); + scratch_size = scratch_size_->value; + input_size = input_size_->value; + output_size = output_size_->value; + c_source = GenerateSource(); + } + + /*! + * \brief Save the module to file. + * + * \param file_name The file to be saved to. + * \param format The format of the file. 
+ */ + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + LOG(INFO) << "format=" << fmt << ";;\n"; + ICHECK_EQ(fmt, "c") << "Can only save to format=" + << "c"; + std::ofstream out(file_name); + out << c_source; + out.close(); + } + + std::string GetSource(const std::string& format) final { return c_source; } + + std::string GetCS() { return cmms_hex; } + + /*! + * \brief Get a PackedFunc from the module. + * + * \param name The name of the function. + * \param sptr_to_self The ObjectPtr that points to this module node. + * + * \return The function pointer when it is found, otherwise, PackedFunc(nullptr). + */ + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); + } + return PackedFunc(); + } + + const char* type_key() const override { return "c"; } + + static Module Create(String func_name, String cmms_hex, String weights_bias_hex, + Integer scratch_size, Integer input_size, Integer output_size) { + auto n = make_object(func_name, cmms_hex, weights_bias_hex, scratch_size, + input_size, output_size); + return Module(n); + } + + private: + String c_source; + Array func_names_; + String cmms_hex; + String weights_bias_hex; + size_t scratch_size; + size_t input_size; + size_t output_size; + int indent_{0}; + + /*! + * \brief Convert the raw string of hex values into a hex string + * + * \param raw the raw string of hex values + * + * \return string formatted as a hex string + */ + std::string GetHexString(const std::string& raw) { + std::stringstream ss; + for (size_t i = 0; i < raw.size() / 2; ++i) { + ss << "\\x" << raw.substr(i * 2, 2); + } + return ss.str(); + } + + /*! 
+ * \brief Emit code that updates the base_addrs array with the base address of the given array + * + * \param index array index for base_addrs and base_addrs_size + * \param name of the array containing relevant data + * + * \return string of code that updates the base_addrs array with the base address of the given + * array + */ + std::string SetBaseAddress(int index, std::string name) { + std::stringstream ss; + ss << " base_addrs[" << index << "] = (uintptr_t)(" << name << ");\n"; + ss << " base_addrs_size[" << index << "] = " << name << "_size;\n"; + return ss.str(); + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + ICHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! \brief Print indents using spaces. */ + void PrintIndents(std::stringstream& ss) { + for (int i = 0; i < indent_; i++) { + ss << ' '; + } + } + + /*! + * \brief Creates a runtime function header + */ + void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string func_name) { + ss << "TVM_DLL int32_t "; + ss << func_name << "(void* input, void* output) {\n"; + } + + /*! + * \brief Creates a cplusplus guard prefix for extern "C" printing + */ + void PrintExternCPrefix(std::stringstream& ss) { + PrintIndents(ss); + ss << "#ifdef __cplusplus\n"; + ss << "extern \"C\" {\n"; + ss << "#endif\n"; + } + + /*! + * \brief Creates a cplusplus guard postfix for extern "C" printing + */ + void PrintExternCPostfix(std::stringstream& ss) { + PrintIndents(ss); + ss << "#ifdef __cplusplus\n"; + ss << "}\n"; + ss << "#endif\n"; + } + + /*! 
+ * \brief Emit code that offloads a subgraph to the NPU + * + * \return string of code that offloads a subgraph to the NPU + */ + std::string GenerateSource() { + std::string func_no_dashes = func_names_[0]; + std::replace(func_no_dashes.begin(), func_no_dashes.end(), '-', '_'); + std::stringstream ss; + + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "#include \n"; + ss << "\n"; + size_t weights_size = (weights_bias_hex.size() / 2); + ss << "static const size_t weights_size = " << std::to_string(weights_size) << ";\n"; + ss << "static const size_t scratch_size = " << std::to_string(scratch_size) << ";\n"; + ss << "// Update linker script to place ethosu_scratch in memory that can be accessed by the " + "NPU\n"; + if (weights_size > 0) { + ss << "__attribute__((section(\"ethosu_scratch\"), aligned(16))) static int8_t weights[" + << weights_size << "] = \""; + ss << GetHexString(weights_bias_hex); + ss << "\";\n"; + } else { + ss << "static int8_t* weights = NULL;\n"; + } + ss << "__attribute__((section(\"ethosu_scratch\"), aligned(16))) static int8_t cms_data_data[" + << cmms_hex.size() / 2 << "] = \""; + ss << GetHexString(cmms_hex); + ss << "\";\n"; + ss << "static const size_t cms_data_size = sizeof(cms_data_data);\n"; + ss << "\n"; + + PrintExternCPrefix(ss); + ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, " + << "size_t in0_size, int8_t* out0, size_t out0_size) {\n"; + ss << " int num_tensors = 5;\n"; + ss << " void* cms_data = (void*)(cms_data_data);\n"; + ss << " int64_t device_type = kDLCPU;\n"; + ss << " int64_t device_id = 0;\n"; + if (scratch_size > 0) { + ss << " int8_t* scratch = (int8_t*) TVMBackendAllocWorkspace(device_type, device_id, " + "(uint64_t)scratch_size, 0, 16);\n"; + } else { + ss << " int8_t* scratch = NULL;\n"; + } + ss << " size_t base_addrs_size[num_tensors];\n"; + ss << " uint64_t base_addrs[num_tensors];\n"; + ss << "\n"; + ss << SetBaseAddress(0, "weights"); + ss << 
SetBaseAddress(1, "scratch"); + ss << SetBaseAddress(2, "scratch"); + ss << SetBaseAddress(3, "in0"); + ss << SetBaseAddress(4, "out0"); + ss << "\n"; + ss << " struct ethosu_driver *drv = ethosu_reserve_driver();\n"; + ss << " int32_t result = ethosu_invoke(drv, cms_data, cms_data_size, base_addrs, " + "base_addrs_size, " + "num_tensors);\n"; + ss << " ethosu_release_driver(drv);\n"; + if (scratch_size > 0) { + ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n"; + } + ss << " if (result != 0) {\n"; + ss << " return -1;\n"; + ss << " } else {\n"; + ss << " return 0;\n"; + ss << " }\n"; + ss << "}\n"; + ss << "\n"; + PrintExternCPostfix(ss); + ss << "\n"; + PrintExternCPrefix(ss); + ss << "// Wrapper function is provided to allow for easier debugging\n"; + ss << "inline static int32_t " + func_no_dashes + "_wrapper_(void* input, void* output) {\n"; + ss << " size_t input_data_size = " << input_size << ";\n"; + ss << " size_t output_data_size = " << output_size << ";\n"; + ss << " return " + func_no_dashes + + "_((int8_t*)input, input_data_size, (int8_t*)output, output_data_size);\n"; + ss << "}\n"; + PrintExternCPostfix(ss); + ss << "\n"; + PrintExternCPrefix(ss); + PrintRuntimeFunctionHeader(ss, func_names_[0]); + EnterScope(); + PrintIndents(ss); + ss << "return " << func_no_dashes << "_wrapper_(input, output);\n"; + ExitScope(); + ss << "}\n"; + PrintExternCPostfix(ss); + + return ss.str(); + } +}; + +class EthosUModule : public Module { + public: + EthosUModule() {} + explicit EthosUModule(ObjectPtr n) : Module(n) {} + /*! \return internal container */ + inline EthosUModuleNode* operator->(); + /*! 
\return internal container */ + inline const EthosUModuleNode* operator->() const; +}; + +inline EthosUModuleNode* EthosUModule::operator->() { + return static_cast(get_mutable()); +} + +TVM_REGISTER_GLOBAL("runtime.module.ethosu.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = EthosUModuleNode::Create(args[0], args[1], args[2], args[3], args[4], args[5]); +}); + +TVM_REGISTER_GLOBAL("runtime.module.ethosu.getcs").set_body_typed([](EthosUModule mod) { + return mod->GetCS(); +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/relay/backend/name_transforms.cc b/src/relay/backend/name_transforms.cc new file mode 100644 index 0000000000000..a6d10a795cf7a --- /dev/null +++ b/src/relay/backend/name_transforms.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "name_transforms.h" + +#include + +#include +#include + +namespace tvm { +namespace relay { +namespace backend { + +std::string ToCFunctionStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Function name is empty"; + ICHECK_EQ(original_name.find("TVM"), 0) << "Function not TVM prefixed"; + + int tvm_prefix_length = 3; + std::string function_name("TVM"); + + bool new_block = true; + for (const char& symbol : original_name.substr(tvm_prefix_length)) { + if (std::isalpha(symbol)) { + if (new_block) { + function_name.push_back(std::toupper(symbol)); + new_block = false; + } else { + function_name.push_back(std::tolower(symbol)); + } + } else if (symbol == '_') { + new_block = true; + } + } + return function_name; +} + +std::string ToCVariableStyle(const std::string& original_name) { + ICHECK(!original_name.empty()) << "Variable name is empty"; + ICHECK_EQ(original_name.find("TVM"), 0) << "Variable not TVM prefixed"; + + std::string variable_name; + variable_name.resize(original_name.size()); + + std::transform(original_name.begin(), original_name.end(), variable_name.begin(), ::tolower); + return variable_name; +} + +std::string CombineNames(const Array& names) { + std::stringstream combine_stream; + ICHECK(!names.empty()) << "Name segments empty"; + + for (const String& name : names) { + ICHECK(!name.empty()) << "Name segment is empty"; + combine_stream << name << "_"; + } + + std::string combined_name = combine_stream.str(); + combined_name.pop_back(); + return combined_name; +} + +std::string SanitizeName(const std::string& name) { + ICHECK(!name.empty()) << "Name is empty"; + + auto multipleSeparators = [](char before, char after) { + return before == '_' && before == after; + }; + auto isNotAlnum = [](char c) { return !std::isalnum(c); }; + std::string sanitized_input = name; + std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_'); + + sanitized_input.erase( + 
std::unique(sanitized_input.begin(), sanitized_input.end(), multipleSeparators), + sanitized_input.end()); + + return sanitized_input; +} + +TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle); +TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle); +TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName); +TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName); +TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName); + +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/name_transforms.h b/src/relay/backend/name_transforms.h new file mode 100644 index 0000000000000..4c1fd3ae56fc5 --- /dev/null +++ b/src/relay/backend/name_transforms.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/backend/name_transforms.h + * \brief Transformations which are applied on names to generate appropriately named compiler + * artifacts + * + * Example: + * ToCFunctionStyle(PrefixName(CombineNames({"Device", "target", "Invoke"}))) + * // TVMDeviceTargetInvoke + * + * ToCFunctionStyle(PrefixGeneratedName(CombineNames({"model", "Run"}))) + * // TVMGenModelRun + * + * ToCVariableStyle(PrefixName(CombineNames({"Device", "target", "t"}))) + * // tvm_device_target_t + * + * ToCVariableStyle(PrefixGeneratedName(CombineNames({"model", "Devices"}))) + * // tvmgen_model_devices + * + */ + +#include +#include +#include + +#include +#include +#include + +#ifndef TVM_RELAY_BACKEND_NAME_TRANSFORMS_H_ +#define TVM_RELAY_BACKEND_NAME_TRANSFORMS_H_ + +namespace tvm { +namespace relay { +namespace backend { + +/*! + * \brief Transform a name to the C variable style assuming it is + * appropriately constructed using the prefixing functions + * \param original_name Original name + * \return Transformed function in the C function style + */ +std::string ToCFunctionStyle(const std::string& original_name); + +/*! + * \brief Transform a name to the C variable style assuming it is + * appropriately constructed using the prefixing functions + * \param name Original name + * \return Transformed function in the C variable style + */ +std::string ToCVariableStyle(const std::string& original_name); + +/*! + * \brief Combine names together for use as a generated name + * \param names Vector of strings to combine + * \return Combined together names + */ +std::string CombineNames(const Array& names); + +/*! + * \brief Apply TVM-specific prefix to a name + * \param names Vector of names to combine to form a combined name + * \return Name with prefix applied or prefix-only if no name passed + */ +inline std::string PrefixName(const Array& names) { return "TVM_" + CombineNames(names); } + +/*! 
+ * \brief Apply generated TVM-specific prefix to a name + * \param names Vector of names to combine to form a combined name + * \return Name with prefix applied or prefix-only if no name passed + */ +inline std::string PrefixGeneratedName(const Array& names) { + return "TVMGen_" + CombineNames(names); +} + +/*! + * \brief Sanitize name for output into compiler artifacts + * \param name Original name + * \return Sanitized name + */ +std::string SanitizeName(const std::string& name); + +} // namespace backend +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_NAME_TRANSFORMS_H_ diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e322ccaff1cee..d37fbeabc2775 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -439,11 +439,10 @@ class LowerTensorExprMutator : public ExprMutator { } // Non-External Relay Function - DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n" - << PrettyPrint(func); + VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func); CCacheKey key = CCacheKey(func, target); CachedFunc lowered_func = compiler_->Lower(key, module_name_); - DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; + VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'"; // Collect all the lowered functions produced for this primitive function. 
Map prim_fns; @@ -452,8 +451,7 @@ class LowerTensorExprMutator : public ExprMutator { CHECK(prim_fn.second.as()) << "must be a prim fn"; prim_fns.Set(prim_fn.first, Downcast(prim_fn.second)); all_prim_fn_vars.push_back(prim_fn.first); - DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) - << "'"; + VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'"; } // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT @@ -559,7 +557,7 @@ class LowerTensorExprMutator : public ExprMutator { // Already lowered by other means so we don't need to mutate // the call if (prim_func->IsInstance()) { - return expr; + return std::move(expr); } // Find the desired target device. @@ -859,23 +857,7 @@ IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_con auto updated_module = LowerTensorExpr(targets, device_context_map, module_name, compiler, process_fn)(module); - // A temporary solution until we can rewrite the auto-scheduler task extraction code to work - // in a more reasonable way. 
- if (backend::IsAutoSchedulerEnabled()) { - const auto* te_compiler_update_weights = - runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights"); - - ICHECK(te_compiler_update_weights != nullptr) - << "auto_scheduler.relay_integration.te_compiler_update_weights"; - - Map weight_map; - - for (auto pair : compiler->GetOpWeights()) { - weight_map.Set(pair.first, pair.second); - } - - (*te_compiler_update_weights)(weight_map); - } + backend::UpdateAutoSchedulerOpWeights(compiler); // Copy the lowered functions into the return module updated_module->Update(compiler->GetLoweredFunctions()); diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 07dfe1768790d..67c7558889fbe 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -26,6 +26,8 @@ #include +#include "te_compiler.h" + namespace tvm { namespace relay { namespace backend { @@ -227,6 +229,23 @@ Map TargetStrModuleMapToTargetModuleMap( return tvm_map; } +void UpdateAutoSchedulerOpWeights(tec::TECompiler compiler) { + if (IsAutoSchedulerEnabled()) { + const auto* te_compiler_update_weights = + runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights"); + + ICHECK(te_compiler_update_weights != nullptr) + << "auto_scheduler.relay_integration.te_compiler_update_weights"; + + Map weight_map; + + for (auto pair : compiler->GetOpWeights()) { + weight_map.Set(pair.first, pair.second); + } + (*te_compiler_update_weights)(weight_map); + } +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index ae8d7d2c23603..f8ff20ece5616 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -44,6 +44,11 @@ namespace tvm { namespace relay { + +namespace tec { +class TECompiler; +} + namespace transform { Pass InlinePrimitives(); } @@ -492,6 +497,15 @@ TargetModuleMapToTargetStrModuleMap(Map input_map); Map TargetStrModuleMapToTargetModuleMap( 
std::unordered_map input_map); +/*! + * \brief Call "weight update callback" to communicate op weights seen during Relay module + * lowering back to the auto scheduler. + * Op weights refer to the number of times each distinct op/workload appears in a given module. + * It is called "use_count" in TECompiler. + * \param TECompiler used in the Relay module lowering step. + */ +void UpdateAutoSchedulerOpWeights(tec::TECompiler compiler); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3eab91d202c2..723a0ea6ee7ed 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -977,6 +977,8 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe for (const auto& cfunc : context_.cached_funcs) { exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++}); } + + backend::UpdateAutoSchedulerOpWeights(context_.compiler); } transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 05fb2a1206208..6924f2598f6fe 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -136,13 +136,13 @@ struct PrimitiveInliner : ExprMutator { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); + VLOG(1) << "Before inlining primitives: " << global << std::endl << PrettyPrint(func); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(global, func, true); - DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false); + VLOG(1) << "After inlining primitives: " << global << std::endl << PrettyPrint(func); } } return module_; diff --git 
a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 851a498377b2d..89f22cfb25b21 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -771,6 +771,8 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") }); Expr PatternRewriter::Rewrite(const Array& callbacks, const Expr& pre) { + VLOG_CONTEXT << "PatternRewriter"; + VLOG(1) << "rewriting:" << std::endl << PrettyPrint(pre); auto post = pre; auto last = post; // rewrite the graph until it stops changing to make sure all rewrites are complete @@ -789,7 +791,9 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E groups_ = grouper.GroupMatches(callback_->pattern, post); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); + VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre); post = this->VisitExpr(post); + VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); count++; } equal = (*structural_equal)(last, post, false, true); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 344d1cae78237..eacd3783d9b12 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -130,8 +130,9 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) ICHECK(mod.defined()); - DLOG(INFO) << "Executing function pass : " << pass_info->name - << " with opt level: " << pass_info->opt_level; + VLOG_CONTEXT << pass_info->name; + VLOG(0) << "Executing function pass with opt level: " << pass_info->opt_level; + VLOG(1) << "Input module:" << std::endl << PrettyPrint(mod); IRModule updated_mod = mod->ShallowCopy(); @@ -155,6 +156,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) pass_ctx->diag_ctx.value().Render(); pass_ctx->diag_ctx = previous; + VLOG(1) << "Output module:" << std::endl << PrettyPrint(updated_mod); + // TODO(@jroesch): move away from eager type checking for performance reasons // make issue. 
return transform::InferType()(updated_mod); diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index b59c5a3e9ff3f..284f8b88ee0d6 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -20,9 +20,11 @@ /*! * * \file src/relay/op/annotation/annotation.cc - * \brief Registration of annotation operators. + * \brief Helpers for working with various 'annotations' attributes. */ +#include "./annotation.h" + #include #include #include @@ -36,15 +38,51 @@ namespace tvm { namespace relay { -// relay.annotation.on_device TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); +const Op& OnDeviceOp() { + static const Op& op = Op::Get("on_device"); + return op; +} + +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + auto attrs = make_object(); + attrs->device_type = device_type; + attrs->is_fixed = is_fixed; + Span span = expr->span; + return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span); +} + +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) { + if (device_type == kInvalidDeviceType) { + // Undefined signals no annotation is required. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // These operators are device polymorphic so no annotation is required. + // TODO(mbs): The device planning pass does NOT currently support device polymorphism for + // constructors, so we could remove them from this condition. However most constructors + // accept type parameters, and it is not well-formed Relay to simply wrap such a + // constructor in an "on_device" call. So we'll pretend they are device polymorphic to + // avoid that difficultly. Overall ADTs need more work to be fully supported. + return expr; + } + if (expr->IsInstance() || expr->IsInstance()) { + // The device can be recovered from the binding site of the global or local variable. 
+ return expr; + } + if (const auto* function_node = expr.as()) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // Primitive functions are device polymorphic, matching our interpretation for OpNode above. + return expr; + } + } + return OnDevice(expr, device_type, is_fixed); +} + TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") - .set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); + .set_body_typed([](Expr expr, int device_type, bool is_fixed) { + return OnDevice(expr, static_cast(device_type), is_fixed); }); RELAY_REGISTER_OP("on_device") @@ -53,15 +91,101 @@ RELAY_REGISTER_OP("on_device") .add_argument("data", "Tensor", "The input data.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.OnDeviceAttrs") .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("TNonComputational", true) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type) -> Array { return {topi::identity(inputs[0])}; }); +OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + if (call_node->op == OnDeviceOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; + ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; + const auto* on_device_attrs = call_node->attrs.as(); + ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; + auto device_type = static_cast(on_device_attrs->device_type); + // Follow nesting: + // on_device(on_device(expr, device_type=1), device_type=2) == {expr, 1} + auto inner = GetOnDeviceProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.device_type, on_device_attrs->is_fixed || inner.is_fixed}; + } else { + return {call_node->args[0], 
device_type, on_device_attrs->is_fixed}; + } + } + return {}; +} + +OnDeviceProps GetOnDeviceProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetOnDeviceProps(call_node); + } + return {}; +} + +Function FunctionOnDevice(Function function, Array param_device_types, + Integer result_device_type) { + return WithAttrs(std::move(function), {{tvm::attr::kParamDeviceTypes, param_device_types}, + {tvm::attr::kResultDeviceType, result_device_type}}); +} + +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType result_device_type) { + Array arr; + arr.reserve(param_device_types.size()); + for (const auto device_type : param_device_types) { + arr.push_back(static_cast(device_type)); + } + return FunctionOnDevice(std::move(function), std::move(arr), + static_cast(result_device_type)); +} + +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type) { + if (std::all_of(param_device_types.begin(), param_device_types.end(), + [](DLDeviceType type) { return type == kInvalidDeviceType; }) && + result_device_type == kInvalidDeviceType) { + return function; + } + return FunctionOnDevice(function, param_device_types, result_device_type); +} + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.function_on_device") + .set_body_typed([](Function function, Array param_device_types, + int result_device_type) { + return FunctionOnDevice(function, param_device_types, + static_cast(result_device_type)); + }); + +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node) { + auto opt_integer = function_node->GetAttr(tvm::attr::kResultDeviceType); + if (!opt_integer) { + // No annotation. 
+ return kInvalidDeviceType; + } + return static_cast(opt_integer.value()->value); +} + +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i) { + ICHECK_LT(i, function_node->params.size()) + << "param index " << i << " out of range for function of arity " + << function_node->params.size(); + auto opt_array = function_node->GetAttr>(tvm::attr::kParamDeviceTypes); + if (!opt_array) { + // No annotation. + return kInvalidDeviceType; + } + ICHECK_EQ(opt_array.value().size(), function_node->params.size()) + << "annotation parameters do not match function arity"; + return static_cast(opt_array.value()[i]->value); +} + Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h new file mode 100644 index 0000000000000..35f8b6bf50b66 --- /dev/null +++ b/src/relay/op/annotation/annotation.h @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/annotation/annotation.h + * \brief Helpers for working with various 'annotation' attributes. 
+ */ +#ifndef TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ +#define TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "on_device" operator. */ +const Op& OnDeviceOp(); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. + * + * See \p OnDeviceAttrs for an overview. + */ +Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed. However + * returns \p expr directly if: + * - \p device_type is \p kInvalidDeviceType, which signals there are no device annotations + * already in play. + * - \p expr is an operator or primitive function literal. These are device polymorphic. + * - \p expr is a global or local var. These already have an implied device. + * - \p expr is a constructor. There should probably be device polymorphic but are in an + * in-between state at the moment. + */ +Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + DLDeviceType device_type = kInvalidDeviceType; + bool is_fixed = false; + + OnDeviceProps() = default; + + OnDeviceProps(const Expr& body, DLDeviceType deviceType, bool isFixed) + : body(body), device_type(deviceType), is_fixed(isFixed) {} +}; + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p call_node if it is + * an "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p + * false. + */ +OnDeviceProps GetOnDeviceProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, device type and is_fixed field for \p expr if it is an + * "on_device" CallNode. Otherwise returns the null expression, \p kInvalidDeviceType and \p false. + */ +OnDeviceProps GetOnDeviceProps(const Expr& expr); + +/*! 
+ * \brief Returns \p function annotated with "param_device_types" and "result_device_type" + * attributes capturing parameter and result devices types respectively. + */ +Function FunctionOnDevice(Function function, Array param_device_types, + Integer body_device_type); +Function FunctionOnDevice(Function function, const std::vector& param_device_types, + DLDeviceType body_device_type); + +/*! + * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and + * result device types are \p kInvalidDeviceType. + */ +Function MaybeFunctionOnDevice(Function function, + const std::vector& param_device_types, + DLDeviceType result_device_type); + +/*! + * \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType + * if function does not have "result_device_type" annotation. + */ +DLDeviceType GetFunctionResultDeviceType(const FunctionNode* function_node); + +/*! + * \brief Returns the device type for the \p i'th parameter of \p function_node, or + * \p kInvalidDeviceType if function does not have "param_device_types" annotation. + */ +DLDeviceType GetFunctionParamDeviceType(const FunctionNode* function_node, size_t i); + +/*! \brief Wraps \p data in a "stop_fusion" annotation. */ +Expr StopFusion(Expr data); + +/*! \brief Wraps \p data in a "cast_hint" annotation for \p dtype. 
*/ +Expr CastHint(Expr data, DataType dtype); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_ANNOTATION_ANNOTATION_H_ diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index d8ee1c84a99c3..64baa6066522c 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -618,6 +618,137 @@ RELAY_REGISTER_OP("dyn.sparse_to_dense") .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr("FTVMCompute", SparseToDenseCompute); +/* relay.dyn.unsqueeze */ +TVM_REGISTER_NODE_TYPE(DynExpandDimsAttrs); + +bool ExpandDimsRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(num_inputs, 2); + const auto* data_type = types[0].as(); + if (data_type == nullptr) { + ICHECK(types[0].as()) + << "expand_dims: expect input type to be TensorType but get " << types[0]; + return false; + } + + const auto* param = attrs.as(); + + // We don't know the output shape until we see the value of the axis input + int ndim = data_type->shape.size(); + Array oshape(ndim + param->num_newaxis, Any()); + + const auto* axis_type = types[1].as(); + ICHECK(axis_type->shape.size() == 0) << "Axis should be a scalar got shape " << axis_type->shape; + + // Set output shape + reporter->Assign(types[2], TensorType(oshape, data_type->dtype)); + return true; +} + +Array ExpandDimsCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + // inputs = [Input tensor, axis to expand] + ICHECK_EQ(inputs.size(), 2); + + const auto* param = attrs.as(); + + Array ishape = inputs[0]->shape; + const TensorTypeNode* out_ttype = out_type.as(); + int ndim_out = out_ttype->shape.size(); + int ndim_in = ishape.size(); + ICHECK_EQ(ndim_in + param->num_newaxis, ndim_out); + + Array newshape; + for (auto val : out_ttype->shape) { + // These vars will be populated by the VM executor with the results + // of the shape_func for the op. 
+ newshape.push_back(val.as()->ToVar()); + } + + return {topi::reshape(inputs[0], newshape)}; +} + +Expr MakeExpandDims(Expr data, Expr axis_tensor, int num_newaxis) { + auto attrs = make_object(); + attrs->num_newaxis = num_newaxis; + static const Op& op = Op::Get("dyn.expand_dims"); + return Call(op, {data, axis_tensor}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.expand_dims").set_body_typed(MakeExpandDims); + +RELAY_REGISTER_OP("dyn.expand_dims") + .describe(R"code(Insert one new axis at the position given by `axis` + +- **data**: The input data to the operator. +- **axis**: The axis to insert a new dimension + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("axis", "Tensor", "The axis to insert at a dimension.") + .set_support_level(3) + .add_type_rel("DynamicExpandDims", ExpandDimsRel) + .set_attr("FTVMCompute", ExpandDimsCompute) + .set_attr("TOpPattern", kInjective); + +bool DynSqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // [data, axes, output] + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* axes = types[1].as(); + if (axes == nullptr) { + return false; + } + + ICHECK_EQ(axes->shape.size(), 1) << "Got" << axes->shape.size() << "expected 1"; + ICHECK(axes->shape[0].as()) << "axes expected to be static rank"; + size_t output_rank = data->shape.size() - axes->shape[0].as()->value; + std::vector result_shape(output_rank, Any()); + reporter->Assign(types[2], TensorType(result_shape, data->dtype)); + return true; +} + +Array SqueezeCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* out_ttype = out_type.as(); + ICHECK(out_ttype != nullptr); + Array newshape; + for (auto val : out_ttype->shape) { + newshape.push_back(val.as()->ToVar()); + } + return {topi::reshape(inputs[0], newshape)}; +} + +Expr 
MakeDynSqueeze(Expr data, Expr axes) { + auto attrs = make_object(); + static const Op& op = Op::Get("dyn.squeeze"); + return Call(op, {data, axes}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.squeeze").set_body_typed(MakeDynSqueeze); + +RELAY_REGISTER_OP("dyn.squeeze") + .describe(R"code(Remove axes of value 1 in input tensor at the dimensions given by axes + +- **data**: The input data to the operator. +- **axes**: The axes to squeeze. + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("axes", "Tensor", "The axes to squeeze.") + .set_support_level(3) + .add_type_rel("DynSqueeze", DynSqueezeRel) + .set_attr("FTVMCompute", SqueezeCompute) + .set_attr("TOpPattern", kInjective) + .set_attr("TReshapeOp", true); + } // namespace dyn } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc new file mode 100644 index 0000000000000..dce89aa91b65a --- /dev/null +++ b/src/relay/op/memory/device_copy.cc @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file relay/op/memory/device_copy.cc + * \brief Helpers for working with "device_copy" attributes. + */ + +#include "./device_copy.h" + +#include +#include +#include +#include +#include + +#include "../../transforms/infer_layout_utils.h" +#include "../type_relations.h" + +namespace tvm { +namespace relay { + +// relay.device_copy +TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); + +const Op& DeviceCopyOp() { + static const Op& op = Op::Get("device_copy"); + return op; +} + +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + Span span = expr->span; + return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(attrs), /*type_args=*/{}, span); +} + +Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + if (src_dev_type == dst_dev_type) { + return expr; + } + ICHECK_NE(src_dev_type, kInvalidDeviceType); + ICHECK_NE(dst_dev_type, kInvalidDeviceType); + return DeviceCopy(expr, src_dev_type, dst_dev_type); +} + +TVM_REGISTER_GLOBAL("relay.op._make.device_copy") + .set_body_typed([](Expr expr, int src_dev_type, int dst_dev_type) { + return DeviceCopy(expr, static_cast(src_dev_type), + static_cast(dst_dev_type)); + }); + +RELAY_REGISTER_OP("device_copy") + .describe(R"code( +Copy data from one tensor to another. The source and destination might be +on different devices. 
+)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attrs_type_key("relay.attrs.DeviceCopyAttrs") + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { + if (call_node->op == DeviceCopyOp()) { + ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument"; + ICHECK(call_node->attrs.defined()) << "device_copy requires attributes"; + const auto* device_copy_attrs = call_node->attrs.as(); + ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; + auto src_dev_type = static_cast(device_copy_attrs->src_dev_type); + auto dst_dev_type = static_cast(device_copy_attrs->dst_dev_type); + // Follow nesting: + // device_copy(device_copy(expr, src_dev_type=1, dst_dev_type=2), + // src_dev_type=2, dst_dev_type=3) ==> {expr, 1, 3} + auto inner = GetDeviceCopyProps(call_node->args[0]); + if (inner.body.defined()) { + return {inner.body, inner.src_dev_type, inner.dst_dev_type}; + } else { + return {call_node->args[0], src_dev_type, dst_dev_type}; + } + } + return {}; +} + +DeviceCopyProps GetDeviceCopyProps(const Expr& expr) { + if (const auto* call_node = expr.as()) { + return GetDeviceCopyProps(call_node); + } + return {}; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h new file mode 100644 index 0000000000000..d21fdb6abe198 --- /dev/null +++ b/src/relay/op/memory/device_copy.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/op/memory/device_copy.h + * \brief Helpers for working with "device_copy" attributes. + */ + +#ifndef TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ +#define TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Returns the "device_copy" operator. */ +const Op& DeviceCopyOp(); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + */ +Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! + * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated on + * a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type. + * However, return \p expr directly if \p src_dev_type equals \p dst_dev_type. + */ +Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type); + +/*! \brief Result of \p GetDeviceCopyProps. 
*/ +struct DeviceCopyProps { + Expr body; // = null + DLDeviceType src_dev_type = kInvalidDeviceType; + DLDeviceType dst_dev_type = kInvalidDeviceType; + + DeviceCopyProps() = default; + + DeviceCopyProps(const Expr& body, DLDeviceType srcDevType, DLDeviceType dstDevType) + : body(body), src_dev_type(srcDevType), dst_dev_type(dstDevType) {} +}; + +/*! + * \brief Returns the body expression, source, and destination device types for \p call_node if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); + +/*! + * \brief Returns the body expression, source, and destination device types for \p expr if it + * is a "device_copy" CallNode. Otherwise returns the null expression and \p kInvalidDeviceType + * device types. + */ +DeviceCopyProps GetDeviceCopyProps(const Expr& expr); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_OP_MEMORY_DEVICE_COPY_H_ diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c2997fb6cf958..5339d48e3a2f1 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -35,9 +35,9 @@ #include #include "../../transforms/infer_layout_utils.h" +#include "../annotation/annotation.h" #include "../op_common.h" #include "../type_relations.h" -#include "tvm/relay/attrs/device_copy.h" namespace tvm { namespace relay { @@ -86,6 +86,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") .add_argument("size", "Tensor", "The size of the storage to allocate.") .add_argument("alignment", "Tensor", "The alignment of the storage.") .add_type_rel("AllocStorage", AllocStorageRel) + .set_attrs_type_key("relay.attrs.AllocStorageAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -97,14 +98,21 @@ RELAY_REGISTER_OP("memory.alloc_storage") return {topi::identity(inputs[0])}; }); -Expr AllocTensor(Expr storage, Expr offset, 
tvm::relay::Expr shape, DataType dtype, +Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; if (assert_shape.defined()) { attrs->assert_shape = assert_shape; } else { - attrs->const_shape = Downcast(shape); + // Look through any on_device for the shape argument expression. + Expr literal_shape = shape; + auto props = GetOnDeviceProps(literal_shape); + if (props.body.defined()) { + // See through on_device calls. + literal_shape = props.body; + } + attrs->const_shape = Downcast(literal_shape); } static const Op& op = Op::Get("memory.alloc_tensor"); return Call(op, {storage, offset, shape}, Attrs(attrs), {}); @@ -193,6 +201,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor") .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) + .set_attrs_type_key("relay.attrs.AllocTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -307,36 +316,5 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType") return ToTupleType(t, std::vector(array.begin(), array.end())); }); -// relay.device_copy -TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); - -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.device_copy").set_body_typed(DeviceCopy); - -RELAY_REGISTER_OP("device_copy") - .describe(R"code( -Copy data from one tensor to another. The source and destination might be -on different devices. 
-)code" TVM_ADD_FILELINE) - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") - .set_support_level(10) - .add_type_rel("Identity", IdentityRel) - .set_attr("TOpPattern", kOpaque) - .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) - .set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index bbbd11867549d..558c409782f57 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -33,7 +33,6 @@ namespace tvm { namespace relay { Expr AllocStorage(Expr size, Expr alignment, Device dev, DataType dtype_hint); -Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape); Expr ToTupleType(const Type& ty, const std::vector& exprs); diff --git a/src/relay/op/nn/convolution_make.h b/src/relay/op/nn/convolution_make.h index 01d6f183f79e5..d343940b9ca75 100644 --- a/src/relay/op/nn/convolution_make.h +++ b/src/relay/op/nn/convolution_make.h @@ -18,7 +18,7 @@ */ /*! - * \file src/relay/op/nn/make_convolution.h + * \file src/relay/op/nn/convolution_make.h * \brief utilities for creating convolution ops */ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_MAKE_H_ diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 693589fecfb46..c9f14c91c7b1a 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -116,13 +116,14 @@ Array GetExcludeAxes(size_t indim, const Array& inaxis) { } // Return the modified layout for AlterOpLayout pass. 
+template InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - const auto* attrs_ptr = attrs.as(); + const auto* attrs_ptr = attrs.as(); ICHECK(attrs_ptr); - ObjectPtr params = make_object(*attrs_ptr); + ObjectPtr params = make_object(*attrs_ptr); // Get the reduce axes. Array> old_in_shapes; @@ -152,11 +153,14 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, for (auto iter_var : layout->axes) { const auto& layout_axis = LayoutAxis::Get(iter_var); const std::string& layout_dim = layout_axis.name(); - if (old_r_dims.count(layout_dim)) { - new_r_axes.push_back(tvm::Integer(axis_index)); - } // Collect only the primal axis. if (layout_axis.IsPrimal()) { + if (old_r_dims.count(layout_dim) && !params->exclude) { + new_r_axes.push_back(tvm::Integer(axis_index)); + } + if (!old_r_dims.count(layout_dim) && params->exclude) { + new_r_axes.push_back(tvm::Integer(axis_index)); + } if (!old_r_dims.count(layout_dim) || params->keepdims) { inferred_out_string += layout_dim; } @@ -171,18 +175,24 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, std::string new_layout_string; Array new_r_axes; + Array new_input_layouts; + + auto check_num_input_layouts = [](Array in_layouts) { + // The second case is for variance op + ICHECK(in_layouts.size() == 1 || in_layouts.size() == 2); + }; if (new_in_layouts.defined() && r_axes.size()) { // Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the // modified layout axes. - ICHECK_EQ(new_in_layouts.size(), 1); - ICHECK_EQ(old_in_layouts.size(), 1); + check_num_input_layouts(new_in_layouts); + check_num_input_layouts(old_in_layouts); // Get inferred_in and inferred_out from new_in_layout. 
std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]); params->axis = new_r_axes; } else if (old_in_layouts.defined()) { - ICHECK_EQ(old_in_layouts.size(), 1); + check_num_input_layouts(old_in_layouts); // If the new layout is undefined, get inferred_in and inferred_out from old_in_layout. if (old_in_layouts[0].defined()) { @@ -190,7 +200,13 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, } } - return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params)); + new_input_layouts.push_back(inferred_in); + + if (old_in_layouts.size() == 2) { + new_input_layouts.push_back(inferred_in); + } + + return InferCorrectLayoutOutput(new_input_layouts, {inferred_out}, Attrs(params)); } template @@ -389,6 +405,7 @@ values over a given axis. .set_support_level(4) .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMaxCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array ArgMinCompute(const Attrs& attrs, const Array& inputs, @@ -405,6 +422,7 @@ values over a given axis. 
.set_support_level(4) .add_type_rel("ArgReduce", GenericReduceRel) .set_attr("FTVMCompute", ArgMinCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array SumCompute(const Attrs& attrs, const Array& inputs, @@ -433,7 +451,7 @@ Example:: .set_attrs_type() .set_support_level(4) .add_type_rel("Reduce", ReduceRel) - .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("FTVMCompute", SumCompute) .set_attr("TOpPattern", kCommReduce); @@ -468,6 +486,7 @@ Example:: .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", AllCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array AnyCompute(const Attrs& attrs, const Array& inputs, @@ -516,6 +535,7 @@ RELAY_REGISTER_REDUCE_OP("max") .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MaxCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array MinCompute(const Attrs& attrs, const Array& inputs, @@ -531,6 +551,7 @@ RELAY_REGISTER_REDUCE_OP("min") .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MinCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array ProdCompute(const Attrs& attrs, const Array& inputs, @@ -551,10 +572,10 @@ Example:: [[1,4],[4,3],[5,2]], [[7,1],[7,2],[7,3]]] - mean(data, axis=1) + prod(data, axis=1) [35562240] - mean(data, axis=[1,2]) + prod(data, axis=[1,2]) [ 36 480 2058] )code" TVM_ADD_FILELINE) @@ -562,6 +583,7 @@ Example:: .set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", ProdCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); Array MeanCompute(const Attrs& attrs, const Array& inputs, @@ -600,6 +622,7 @@ Example:: 
.set_support_level(4) .add_type_rel("Reduce", ReduceRel) .set_attr("FTVMCompute", MeanCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -675,6 +698,7 @@ RELAY_REGISTER_OP("variance") .add_argument("mean", "Tensor", "The mean tensor.") .add_type_rel("Variance", VarianceRel) .set_attr("FTVMCompute", VarianceCompute) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) .set_attr("TOpPattern", kCommReduce); } // namespace relay diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc index a74a259a114f6..be31b54829379 100644 --- a/src/relay/op/vm/vm.cc +++ b/src/relay/op/vm/vm.cc @@ -50,6 +50,7 @@ RELAY_REGISTER_OP("vm.shape_of") .set_num_inputs(1) .add_argument("tensor", "Tensor", "The input tensor") .add_type_rel("ShapeOf", ShapeOfRel) + .set_attrs_type_key("relay.attrs.ShapeOfAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) @@ -131,6 +132,7 @@ RELAY_REGISTER_OP("vm.shape_func") .add_argument("func", "Function", "The operation to call") .add_argument("ins", "Tuple", "The input tensors.") .add_argument("outs", "Tuple", "The output tensors.") + .set_attrs_type_key("relay.attrs.ShapeFuncAttrs") .add_type_rel("ShapeFuncRel", ShapeFuncRel) .set_support_level(10) .set_attr("TOpPattern", kOpaque) @@ -214,6 +216,7 @@ RELAY_REGISTER_OP("vm.reshape_tensor") .add_argument("data", "Tensor", "The input tensor") .add_argument("shape", "Tensor", "The output shape tensor") .add_type_rel("ReshapeTensor", ReshapeTensorRel) + .set_attrs_type_key("relay.attrs.ReshapeTensorAttrs") .set_support_level(10) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index c65cc18799327..6cd596a814acc 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -26,7 +26,7 @@ #include -#include 
"../transforms/pattern_utils.h" +#include "../op/annotation/annotation.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 968628fbfe39c..e636130f8553d 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -29,8 +29,8 @@ #include #include +#include "../op/annotation/annotation.h" #include "../qnn/utils.h" -#include "../transforms/pattern_utils.h" #include "./quantize.h" namespace tvm { diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index 9afdb7210cba0..f347eddae7608 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -50,19 +50,6 @@ namespace alter_op_layout { class AlterTransformMemorizerNode : public TransformMemorizerNode { public: static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode"; -}; - -/*! - * \brief Container that provides the transformation function for alter layout.. - */ -class AlterTransformMemorizer : public TransformMemorizer { - public: - AlterTransformMemorizer() {} - explicit AlterTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} - - AlterTransformMemorizerNode* operator->() { - return static_cast(get_mutable()); - } /*! * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by @@ -102,7 +89,23 @@ class AlterTransformMemorizer : public TransformMemorizer { return GetRef(new_call); } - using TransformMemorizer::CallWithNewLayouts; + Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } +}; + +/*! + * \brief Container that provides the transformation function for alter layout.. 
+ */ +class AlterTransformMemorizer : public TransformMemorizer { + public: + AlterTransformMemorizer() = default; + explicit AlterTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} + + AlterTransformMemorizerNode* operator->() { + return static_cast(get_mutable()); + } + using ContainerType = AlterTransformMemorizerNode; }; @@ -113,10 +116,12 @@ class AlterTransformMemorizer : public TransformMemorizer { */ Expr AlterOpLayout(const Expr& expr) { // TODO(@icemelon9): need to rerun type inference after applying an alter op. - AlterTransformMemorizer alterMemorizer(make_object()); - auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; }; - - return ForwardRewrite(expr, LayoutRewriter, fcontext); + AlterTransformMemorizer alter_memorizer(make_object()); + std::function fcontext = [=](const Call& call) -> ObjectRef { + return alter_memorizer; + }; + FForwardRewrite rewrite_func = LayoutRewriter; + return ForwardRewrite(expr, rewrite_func, fcontext); } } // namespace alter_op_layout diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index e74ea01158575..e10be508529e0 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -58,22 +58,6 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode { explicit ConvertTransformMemorizerNode(Map> desired_layouts) : desired_layouts_(std::move(desired_layouts)) {} - /*! \brief A mapping of op_name to array of desired layouts for each input. */ - Map> desired_layouts_; -}; - -/*! - * \brief Container that provides the transformation function for convert layout. - */ -class ConvertTransformMemorizer : public TransformMemorizer { - public: - ConvertTransformMemorizer() {} - explicit ConvertTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} - - ConvertTransformMemorizerNode* operator->() { - return static_cast(get_mutable()); - } - /*! * \brief Defines the call transformation for ConvertLayout pass. 
The new layouts should be the * desired layout as specified by the user. @@ -89,7 +73,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { Expr new_e; bool modified = false; if (fconvert_layout.count(op)) { - auto desired_layouts = operator->()->desired_layouts_; + auto desired_layouts = desired_layouts_; if (desired_layouts.find(op->name) != desired_layouts.end()) { tvm::Array tinfos; for (auto& expr : ref_call->args) { @@ -124,7 +108,26 @@ class ConvertTransformMemorizer : public TransformMemorizer { return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span); } - using TransformMemorizer::CallWithNewLayouts; + Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } + + /*! \brief A mapping of op_name to array of desired layouts for each input. */ + Map> desired_layouts_; +}; + +/*! + * \brief Container that provides the transformation function for convert layout. + */ +class ConvertTransformMemorizer : public TransformMemorizer { + public: + ConvertTransformMemorizer() = default; + explicit ConvertTransformMemorizer(ObjectPtr n) : TransformMemorizer(n) {} + + ConvertTransformMemorizerNode* operator->() { + return static_cast(get_mutable()); + } + using ContainerType = ConvertTransformMemorizerNode; }; diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index 02f9d474411ab..7457457e4c5c1 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -18,7 +18,7 @@ */ /*! - * \file deivce_annotation.cc + * \file device_annotation.cc * \brief Passes to rewrite annotated program and retrieve the device allocation * of expression. 
* diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc new file mode 100644 index 0000000000000..204bce53207b5 --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.cc @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.cc + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#include "./device_aware_visitors.h" + +namespace tvm { +namespace relay { +namespace transform { + +// TODO(mbs): We'd probably have less tendious code duplication if we redefined the memoizing +// mutator on top of the generic Functor. + +DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { + auto props = GetOnDeviceProps(expr); + if (props.body.defined() && props.is_fixed) { + // Look through any fixed "on_device" annotations. + return props.device_type; + } + if (expr->IsInstance()) { + // Lookup variable binding. 
+ auto itr = var_device_types_.find(Downcast(expr)); + if (itr == var_device_types_.end()) { + return kInvalidDeviceType; + } else { + return itr->second; + } + } + // Otherwise use the currently in-scope device type. + if (expr_device_types_.empty()) { + return kInvalidDeviceType; + } else { + return expr_device_types_.back(); + } +} + +void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } + +void LexicalOnDeviceMixin::ExitFunctionBody() { + ICHECK_GT(function_nesting_, 0); + --function_nesting_; +} + +void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + expr_device_types_.emplace_back(device_type); +} + +void LexicalOnDeviceMixin::PopDeviceType() { + if (expr_device_types_.empty()) { + return; + } + expr_device_types_.pop_back(); +} + +void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + ICHECK(var_device_types_.find(var) == var_device_types_.end()); + var_device_types_.emplace(std::move(var), device_type); +} + +void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { + auto itr = var_device_types_.find(var); + if (itr == var_device_types_.end()) { + return; + } + var_device_types_.erase(itr); +} + +void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. 
+ ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec). + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + PopBoundVar((*itr)->var); + PostVisitLet_(*itr); + } + PostVisitLetBlock_(let_node); +} + +void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. 
+ PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + ExprVisitor::VisitExpr_(function_node); +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); +} + +void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); +} + +void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + Expr result = DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + + return result; + } +} + +Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector> bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) 
+ PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); + expr = inner_let_node->body; + } + + expr = VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* pre_let_node = std::get<3>(*itr); + PopBoundVar(pre_let_node->var); + Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), + /*body=*/expr, /*span=*/std::get<2>(*itr)); + expr = PostVisitLet_(pre_let_node, post_let.get()); + } + return PostVisitLetBlock_(let_node, expr.as()); +} + +Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + Expr expr = VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. 
+ PopDeviceType(); + return OnDevice(expr, props.device_type, props.is_fixed); + } else { + return DeviceAwareVisitExpr_(call_node); + } +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return ExprMutator::VisitExpr_(function_node); +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { + return ExprMutator::VisitExpr_(call_node); +} + +void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ +} + +std::pair DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, + const Expr& value) { + return std::make_pair(Downcast(VisitExpr(var)), VisitExpr(value)); +} + +Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h new file mode 100644 index 0000000000000..8611f87efa06b --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.h @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.h + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/annotation/annotation.h" + +namespace tvm { +namespace relay { +namespace transform { + +/*! + * \brief Helper class for expression transformers which need to keep track of the device + * holding the results of expressions and bound variables. This is recovered from the + * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices + * pass. + * + * \sa \p DeviceAwareExpr{Visitor,Mutator}. + */ +class LexicalOnDeviceMixin { + protected: + /*! + * \brief Returns the device type on which the result of \p expr should/will be stored, assuming + * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if + * stack is empty and no bound vars have device types. + */ + DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + + /*! \brief Indicate a function body is being entered. */ + void EnterFunctionBody(); + + /*! \brief Indicate a function body has been processed. */ + void ExitFunctionBody(); + + /*! \brief Push a device type onto the lexical device stack. 
Ignore if \p kInvalidDeviceType. */ + void PushDeviceType(const DLDeviceType device_type); + + /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ + void PopDeviceType(); + + /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + * + * CAUTION: Despite the name we don't support re-entering the same function body. + */ + void PushBoundVar(Var var, DLDeviceType device_type); + + /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + void PopBoundVar(const Var& var); + + /*! + * \brief Returns the number of function definitions wrapping the currently visited expression. + */ + int function_nesting() const { return function_nesting_; } + + private: + /*! + * \brief The number of function bodies entered. Since many transforms need to distinguish global + * functions from local functions this supports the mixin's \p is_global() helper method. + */ + int function_nesting_ = 0; + + /*! + * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. + * When visiting an expression other than a variable we can assume the expression result is + * to be stored on device_type_.back(). + */ + std::vector expr_device_types_; + /*! + * \brief A map from in-scope variable to their device types. We may assume the variable is only + * ever bound to a value stored on this device at runtime. + */ + std::unordered_map + var_device_types_; +}; + +template +class DeviceAwareExprFunctor; + +/*! + * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without + * any memoization. 
+ */ +template <> +class DeviceAwareExprFunctor : public ExprFunctor, + public LexicalOnDeviceMixin { + private: + using TSuper = ExprFunctor; + + public: + void VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } + } + + void VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* visited_let_node = *itr; + PopBoundVar(visited_let_node->var); + PostVisitLet_(visited_let_node); + } + PostVisitLetBlock_(let_node); + } + + void VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. 
+ PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } + } + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return TSuper::VisitExpr_(function_node); + } + + virtual void DeviceAwareVisitExpr_(const CallNode* call_node) { + return TSuper::VisitExpr_(call_node); + } + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); + } + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node) {} +}; + +/*! \brief ExprVisitor which tracks devices. 
*/ +class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { + public: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const FunctionNode* function_node) final; + void VisitExpr_(const LetNode* let_node) final; + void VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual void DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node); +}; + +/*! \brief ExprMutator which tracks devices. */ +class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { + public: + Expr VisitExpr_(const FunctionNode* function_node) final; + Expr VisitExpr_(const LetNode* let_node) final; + Expr VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. 
Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation just returns a reference to the post-visited node. + */ + virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation returns reference to let node. + */ + virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc new file mode 100644 index 0000000000000..15784856edbf5 --- /dev/null +++ b/src/relay/transforms/device_domains.cc @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.cc + * \brief Unification domain for the device planner. + */ + +#include "./device_domains.h" + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +// Ye olde boost hash mixer. +constexpr size_t mix(size_t h1, size_t h2) { + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); +} + +/*! + * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather + * than the original "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs == nullptr) { + return {}; + } + if (tir_call_attrs->metadata.count("source_device") != 1 || + tir_call_attrs->metadata.count("dst_device") != 1) { + return {}; + } + ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; + return { + call_node->args[0], + static_cast( + Downcast(tir_call_attrs->metadata["source_device"])->value), + static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; +} + +} // namespace + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. 
+ +size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const { + if (domain->is_free()) { + // Give each free first-order domain its own identity. + return static_cast(reinterpret_cast(domain.get())); + } else { + size_t h = domain->args_and_result_.size(); + h = mix(h, std::hash()(static_cast(domain->device_type_))); + for (const auto& sub_domain_ptr : domain->args_and_result_) { + h = mix(h, DeviceDomainHash()(sub_domain_ptr)); + } + return h; + } +} + +bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { + // Mismatched arities are never equal. + // (Though we'll never ask to do such a comparison explicitly, the hash map + // may do so implicitly due to hash collisions.) + return false; + } + if (lhs->is_free() && rhs->is_free()) { + // Compare first-order free domains by their address. + return lhs.get() == rhs.get(); + } + if (lhs->args_and_result_.empty()) { + // Compare first-order domains by their device type -- free vs bound will compare as false. + return lhs->device_type_ == rhs->device_type_; + } else { + // Compare higher-order domains pointwise. 
+ for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { + return false; + } + } + return true; + } +} + +/* static */ +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) { + if (const auto* func_type_node = type.as()) { + std::vector args_and_result; + args_and_result.reserve(func_type_node->arg_types.size() + 1); + for (const auto& arg_type : func_type_node->arg_types) { + args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + } + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + return std::make_shared(std::move(args_and_result)); + } else { + return std::make_shared(device_type); + } +} + +DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { + DeviceDomainPtr root = domain; + while (true) { + auto itr = domain_to_equiv_.find(root); + if (itr == domain_to_equiv_.end()) { + break; + } + ICHECK_NE(itr->second, root); + root = itr->second; + ICHECK_NOTNULL(root); + } + // Path compression. + while (domain != root) { + auto itr = domain_to_equiv_.find(domain); + ICHECK(itr != domain_to_equiv_.end()); + domain = itr->second; + ICHECK_NOTNULL(domain); + itr->second = root; + } + return root; +} + +DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + // TODO(mbs): Proper diagnostics. + ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) + << "Device domains:" << std::endl + << ToString(lhs) << std::endl + << "and" << std::endl + << ToString(rhs) << std::endl + << "do not have the same kind and can't be unified."; + if (rhs->is_free()) { + return lhs; + } else if (lhs->is_free()) { + return rhs; + } else if (lhs->args_and_result_.empty()) { + // Must have consistent device types for first order domains. + if (lhs->device_type_ != rhs->device_type_) { + // TODO(mbs): Proper diagnostics. 
+ std::ostringstream os; + os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; + throw Error(os.str()); + } + return lhs; + } else { + // Recurse for higher-order. + std::vector args_and_result; + args_and_result.reserve(lhs->args_and_result_.size()); + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + } + return MakeDomain(std::move(args_and_result)); + } +} + +DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto joined_domain = Join(lhs, rhs); + if (!DeviceDomainEqual()(lhs, joined_domain)) { + domain_to_equiv_.emplace(lhs, joined_domain); + } + if (!DeviceDomainEqual()(rhs, joined_domain)) { + domain_to_equiv_.emplace(rhs, joined_domain); + } + return joined_domain; +} + +void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (!lhs->is_higher_order() && rhs->is_higher_order()) { + Collapse(lhs, rhs); + } else { + Unify(lhs, rhs); + } +} + +DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) { + ICHECK(expr.defined()); + auto itr = expr_to_domain_.find(expr.get()); + if (itr != expr_to_domain_.end()) { + return Lookup(itr->second); + } + auto domain = Free(expr->checked_type()); + expr_to_domain_.emplace(expr.get(), domain); + return domain; +} + +DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { + auto itr = call_to_callee_domain_.find(call.get()); + if (itr != call_to_callee_domain_.end()) { + return Lookup(itr->second); + } + std::vector args_and_result; + + auto on_device_props = GetOnDeviceProps(call.get()); + auto device_copy_props = GetDeviceCopyProps(call.get()); + if (!device_copy_props.body.defined()) { + device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + } + + if (on_device_props.body.defined()) { + // on_device(expr, device_type=, is_fixed=false) + // on_device 
: fn():?x? + // + // on_device(expr, device_type=, is_fixed=true) + // on_device: fn(): + args_and_result.emplace_back( + ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + if (on_device_props.is_fixed) { + args_and_result.emplace_back(args_and_result.front()); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + } else if (device_copy_props.body.defined()) { + // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy: fn(): + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + } else if (call->op == alloc_storage_op) { + ICHECK_EQ(call->args.size(), 2U); + // alloc_storage(size, alignment, device_type=) + // alloc_storage: fn(, ): + const auto* attrs = call->attrs.as(); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back( + ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + } else if (call->op == alloc_tensor_op) { + ICHECK_EQ(call->args.size(), 3U); + // alloc_tensor(storage, offset, shape) + // alloc_tensor: fn(?x?, , ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op == shape_func_op) { + ICHECK_EQ(call->args.size(), 3U); + // shape_func(func, inputs, outputs, is_inputs=[...]) + // shape_func: fn(..., , ): + // where ... is a free domain appropriate for func's type + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + // TODO(mbs): I think this should be on the cpu only when is_input = [false], but + // what do we do when we have multiple arguments with different is_input values? 
+ args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == shape_of_op) { + ICHECK_EQ(call->args.size(), 1U); + // shape_of(tensor) + // shape_of: fn(?x?): + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == invoke_tvm_op) { + ICHECK_EQ(call->args.size(), 3U); + // invoke_tvm_op(op, inputs, outputs) + // invoke_tvm_op: fn(..., ?x?, ?x?):?x? + // where ... is a free domain appropriate for op's type + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + } else if (call->op == reshape_tensor_op) { + ICHECK_EQ(call->args.size(), 2U); + // reshape_tensor(data, shape) + // reshape_tensor: fn(?x?, ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x?, ..., ?x?):?x? + // (all args and result must be first-order). + auto free_domain = Free(arb_); + for (size_t i = 0; i < call->args.size(); ++i) { + args_and_result.emplace_back(free_domain); + } + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x1?, ..., ?xn?):?xr? + // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?. + // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear + // in callee positions. 
+ const auto* func_type_node = call->op->checked_type().as(); + ICHECK_NOTNULL(func_type_node); + ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); + auto result_domain = Free(func_type_node->ret_type); // first-order + for (const auto& arg_type : func_type_node->arg_types) { + auto param_domain = Free(arg_type); // possibly higher-order + UnifyCollapsed(result_domain, param_domain); // collapse if required + args_and_result.emplace_back(param_domain); + } + args_and_result.emplace_back(result_domain); + } else { + // Defer to normal case where op can be an arbitrary expression. + return DomainFor(call->op); + } + auto domain = MakeDomain(std::move(args_and_result)); + call_to_callee_domain_.emplace(call.get(), domain); + return domain; +} + +void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + try { + Unify(lhs_domain, rhs_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. 
+ LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; +} + +void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); +} + +void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. 
+ Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } +} + +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, + DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); +} + +std::string DeviceDomains::ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } else if (domain->args_and_result_.empty()) { + // first-order bound + os << "<" << domain->device_type_ << ">"; + } else { + // higher-order + os << "fn("; + for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { + if (i > 0) { + os << ","; + } + os << ToString(domain->args_and_result_[i]); + } + os << "):" << ToString(domain->args_and_result_.back()); + } + return os.str(); +} + +std::string DeviceDomains::ToString() { + std::ostringstream os; + for (const auto& pair : expr_to_domain_) { + os << "expression:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + for (const auto& pair : call_to_callee_domain_) { + os << "call:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "callee domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + return os.str(); +} + +DeviceDomainPtr 
DeviceDomains::ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); + } + return domain; +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h new file mode 100644 index 0000000000000..a29370a0e8077 --- /dev/null +++ b/src/relay/transforms/device_domains.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.h + * \brief Unification domain for the device planner. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/*! + * \brief Represents the domain over which we collect equality constraints. + * + * \code + * D ::= ?x? 
-- first order, free + * | -- first order, bound + * | fn(D1, ..., Dn):Dr -- higher order + * \endcode + * + * We require a function value to be on the same device as its result. To support that we need + * a notion of the 'result domain' of a domain: + * \code + * result_domain(?x?) = ?x? + * result_domain() = + * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) + * \endcode + */ +class DeviceDomain { + public: + /*! + * \brief Constructs a first-order domain of \p device_type, which may be + * \p kInvalidDeviceType to indicate the domain is free. + */ + explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + + /*! + * \brief Constructs a higher-order domain, where \p args_and_result contain the + * function argument and result domains in order. + */ + explicit DeviceDomain(std::vector args_and_result) + : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + + /*! \brief Returns true if domain is first-order and free. */ + bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } + + /*! \brief Returns true if domain is higher-order. */ + bool is_higher_order() const { return !args_and_result_.empty(); } + + DLDeviceType first_order_device_type() const { + ICHECK(args_and_result_.empty()); + return device_type_; + } + + size_t function_arity() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.size() - 1UL; + } + + DeviceDomainPtr function_param(size_t i) const { + ICHECK(!args_and_result_.empty()); + ICHECK_LT(i + 1, args_and_result_.size()); + return args_and_result_[i]; + } + + DeviceDomainPtr function_result() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.back(); + } + + private: + /*! + * \brief If this is a function domain then always kInvalidDevice. Otherwise will be + * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is + * bound. 
+ */ + const DLDeviceType device_type_; + + /*! + * \brief If this is a function domain then the sub-domains for each of the function's + * arguments, and the domain for its result. Otherwise empty. + */ + const std::vector args_and_result_; + + friend struct DeviceDomainHash; + friend struct DeviceDomainEqual; + friend class DeviceDomains; +}; + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. +struct DeviceDomainHash { + size_t operator()(const DeviceDomainPtr& domain) const; +}; + +struct DeviceDomainEqual { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const; +}; + +/*! + * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation + * built up by calls to \p Unify. + */ +class DeviceDomains { + public: + DeviceDomains() = default; + + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound + * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain + * will be free. + */ + static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type); + + /*! + * \brief Returns a higher-order domain with \p args_and_results. + */ + static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + return std::make_shared(std::move(arg_and_results)); + } + + /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ + static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { + ICHECK_NE(device_type, kInvalidDeviceType); + return MakeDomain(type, device_type); + } + + /*! \brief Returns a free domain appropriate for \p type. */ + static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + + /*! \brief Returns the domain representing the equivalence class containing \p domain. */ + DeviceDomainPtr Lookup(DeviceDomainPtr domain); + + /*! 
+ * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. + * + * Throws \p Error on failure. + */ + DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p + * rhs disagree on bound device type. + * + * Throws \p Error on failure. + */ + // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but + // given we have refs to functions I'm prepared to be surprised. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + + /*! + * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, + * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as + * \p Unify. + * + * Throws \p Error on failure. + */ + void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! \brief Returns true if a domain is known for \p expr. */ + bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } + + /*! \brief Returns the domain representing \p expr. */ + DeviceDomainPtr DomainFor(const Expr& expr); + + /*! + * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the + * callee is a primitive or special operation we handle it specially. Otherwise defers to \p + * DomainFor(call->op). + * + * This special handling is needed: + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle some special ops which constrain devices to the CPU. + * - To allow the same primitive to be called on different devices at different call sites. + * Since each call to the op can have a different domain we index the ops by the call expression + * rather than the op itself. + */ + DeviceDomainPtr DomainForCallee(const Call& call); + + /*! \brief Unifies the domains for expressions \p lhs and \p rhs. 
*/ + void UnifyExprExact(const Expr& lhs, const Expr& rhs); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + */ + void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + * If \p expected_domain is higher-order but \p expr is first-order, require all arguments + * and the result of \p expected_domain to have the same domain as for \p expr. + */ + void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! \brief Returns true if \p domain contains any free sub-domains. */ + bool AnyFree(DeviceDomainPtr domain); + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Throws \p Error on failure. + */ + void Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain); + + /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ + void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type); + + /*! + * \brief If \p domain is higher-order and its result domain is free, force it to + * \p default_device_type. Then force any remaining free domains to the result domain + * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + */ + void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type); + + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain); + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString(); + + /*! + * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). 
+ */ + DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); + + /*! + * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain + * comment). + */ + DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_device_type(); + } + + private: + /*! \brief Intrinsics we need to handle specially. */ + const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); + const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); + const Op& shape_of_op = Op::Get("vm.shape_of"); + const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); + const Op& shape_func_op = Op::Get("vm.shape_func"); + const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + /*! \brief The CPU device type for special operators such as dynamic shape functions. */ + const DLDeviceType cpu_device_type_ = kDLCPU; + /*! \brief Placeholder for any first-order type. */ + Type arb_ = TupleType(); + /*! \brief The domain for first-order expressions on the CPU. */ + DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + /*! \brief Maps expressions to their domains as determined during analysis. */ + std::unordered_map expr_to_domain_; + + /*! + * \brief Maps call expressions to the domains for their callee where the callee is a primitive. + */ + std::unordered_map call_to_callee_domain_; + + /*! \brief Maps device domains to their equivalent domains as determined during unification. */ + std::unordered_map + domain_to_equiv_; +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc new file mode 100644 index 0000000000000..35bf406263e40 --- /dev/null +++ b/src/relay/transforms/device_planner.cc @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_planner.cc + * \brief Determines a unique device to hold the result of every Relay sub-expression. + * + * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. + * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the + * specific target associated with D (this is recovered independently via a TargetMap), and we + * do not track the storage scope within D (this is yet to be implemented). + * + * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', + * see below. + * + * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and + * 'dst_dev_type' device type, which constrain the argument and context of the call + * respectively. It is ok if source and destination devices are the same, such no-op copies + * will be removed after accounting for the device preference. + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * constrains the argument of the call, but (usually, see below) leaves the context + * unconstrained. 
These are called 'annotations' in the rest of the code, have no operational + * significance by themselves, but may trigger the insertion of a new "device_copy". + * - In two situations the result of an "on_device" CallNode may also be constrained to the + * given device: + * - The "on_device" call occurs at the top-level of a function body, or occurs as an + * immediately let-bound expression. In this situation the extra degree of freedom in + * the function result and let-binding leads to surprising device copies, so we simply + * force the function result or let-bound variable to the given device. + * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted + * it ourselves during an earlier invocation of this pass. This helps make this pass + * idempotent. + * + * We proceed in four phases: + * + * Phase 0 + * ------- + * We rewrite the programs to handle some special cases: + * - "on_device" calls at the top-level of function or immediately let-bound are rewritten + * to have \code is_fixed=true \endcode. + * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written + * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from + * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * + * Phase 1 + * ------- + * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see + * below) to all other Relay sub-expressions. (For idempotence we also respect any existing + * "param_device_types" and "result_device_type" function attributes we introduce below.) + * + * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the + * same device. However each call site can use a different device. In other words primitives are + * 'device polymorphic' since we compile and execute them for each required device. 
+ * + * For most Relay expressions the device for the overall expression is the same as the device + * for it's sub-expressions. E.g. each field of a tuple must be on the same device as the tuple + * itself, the condition and arms of an \p if must all be on the same device as the overall if, + * and so on. + * + * Some special ops (or 'dialects') are handled: + * - Relay supports computing the shape of tensors and operators at runtime using "shape_of", + * "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors + * they describe may reside on any device. + * - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again + * shapes reside on the CPU, but the allocated tensors may reside on any device. + * + * Two Relay expression have special handling: + * - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the + * overall let. However the result of \p e1 may be on a different device. + * - For a function \code fn(x, y) { body } \endcode the result of the function must be on the + * same device as \p body. However parameters \p x and \p may be on different devices, even + * different from each other. Every call to the function must use the same choice of parameter + * and result devices -- there is no 'device polymorphism' for Relay functions. + * + * Phase 2 + * ------- + * After flowing constraints we apply some defaulting heuristics (using a global default device) + * to fix the device for any as-yet unconstrained sub-expressions. + * - Unconstrained function result devices default to the global default device. + * - Unconstrained function parameters devices default to the device for the function result. + * - Unconstrained let-bound expression devices default to the device for the overall let. + * TODO(mbs): I may have over-innovated here and we simply want to bind all free domaints to + * the global default device. Worth a design doc with motivating examples I think. 
+ * + * Phase 3 + * ------- + * Finally, the result of this analysis is reified into the result as: + * - Additional "param_device_types" (an Array) and "result_device_type" (Integer) + * attributes for every function (both top-level and local). These describe the devices for + * the function's parameters and the result. + * - Additional "device_copy" CallNodes where a copy is required in order to respect the + * intent of the original "on_device" CallNodes. + * - Additional "on_device" CallNodes where the device type of an expression does not match + * that of the lexically enclosing "on_device" CallNode or function attribute. In practice + * this means "on_device" CallNodes may appear in two places: + * - On a let-bound expression if its device differs from the overall let expression. + * - On a call argument if its device differs from the call result. In particular, the + * argument to a "device_copy" call will always be wrapped in an "on_device". (That may + * seem pedantic but simplifies downstream handling.) + * However since we make it easy to track devices for variables we never wrap an "on_device" + * around a var or global var. These uses of "on_device" imply both the argument and result are + * on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, + * which helps make this pass idempotent. + * + * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover + * the device for any expression for their own use, e.g. during memory planning. All downstream + * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. 
conversion
+ * to ANF must respect the lexical scoping convention:
+ * \code
+ * f(on_device(g(h(a, b), c), device_type=CPU))
+ * ==>
+ * let %x0 = on_device(h(a, b), device_type=CPU)
+ * let %x1 = on_device(g(%x0), device_type=CPU)
+ * f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass can be run before FuseOps so it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (eg an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on'
+ * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just
+ * compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to heuristically
+ * minimize cross-device calls by moving device copies out of functions.
E.g.: + * \code + * def @f() { // execute on CPU + * let x = on_device(...GPU computation..., device_type=GPU); + * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) + * } + * def @main() { + * ... call @f() on CPU ... + * } + * \endcode + * could be rewritten to: + * \code + * def @f() { // execute on GPU + * let x = ...GPU computation...; + * ...GPU computation... + * } + * def @main() { + * let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU) + * ... use x on CPU ... + * } + * \endcode + * + * Higher-order shenanigans + * ------------------------ + * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions + * as arguments (even anonymous functions), return functions, evaluate conditional expressions + * over functions, and so on. We handle this during constraint solving using the domain: + * \code + * D ::= -- first-order + * | fn(D,...,D):D -- higher-order + * \endcode + * In this way we can determine the device for all function parameters and results. E.g. for + * \code + * let f = fn(x, y) { ... } + * let g = fn(f, z) { f(z, z) } + * g(f, on_device(..., device_type=CPU)) + * \endcode + * the parameters \p x and \p y will be on the CPU. + * + * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a + * function. Our analysis must guarantee that the function's parameters and result devices are + * consistent for \p e2, \p e3, and the context of the call. But: + * - Which device holds the closure result of evaluating \p e1 ? + * - If \p e2 is of function type, what does that mean when we say every function parameter + * is on a device? + * - If \p e1 returns a function, what does that mean when we say every function result is + * on a device? + * + * Since higher-order aspects are later compiled away (by 'defunctionalization' + * aka 'firstification') we'd prefer not to have to answer any of those questions. 
In particular,
+ * we really don't want our domain \p D to allow for yet another device for the function closure.
+ * So we'll just force the 'device for a function' to be the same as the device for the function's
+ * result using the notion of the 'result domain' for a domain:
+ * \code
+ * result_domain() =
+ * result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr)
+ * \endcode
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ * -------
+ * | AOR | This pass supports all of Relay.
+ * -------
+ * ^
+ * |
+ * `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ * * Though on_device is the identity for all types we can't wrap it around functions/constructors
+ * taking type args (or at least not without changing type_infer.cc to see through them).
+ * This is not currently handled generally.
+ * * Proper diagnostics for unification failure using spans.
+ * * Make sure the pass is idempotent even after FuseOps etc.
+ * * Support application of constructors properly. Are they device polymorphic?
+ * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'.
+ * * Support running the pass post FuseOps (so need to understand primitive functions, both
+ * outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ * forms?).
+ * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ * device for primitives vs the default device for the rest of Relay.
+ * * We'll probably need some support for partial 'device polymorphism' for functions once we + * incorporate targets and memory scopes into the domain. For example it's ok for the function + * body to be executed on different device ids provided they have the same target and memory + * scope. + * * Might be simpler to just let every type have a device annotation rather than work in + * a separate domain? + * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls + * in tuples at the top level of function bodies or main expression, irrespective of the + * "on_device" body. What's up with that? + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" +#include "./device_domains.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +/****** +******* Phase 0 +*******/ + +/*! + * \brief Rewrites "on_device" calls to handle some special cases. + * + * \code + * let %x = on_device(e, device_type=d) + * ==> let %x = on_device(e, device_type=d, is_fixed=True) + * + * fn(%x) { on_device(e, device_type=d) } + * ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) + * + * on_device(e).0 + * ==> on_device(e.0) + * \endcode + */ +class RewriteOnDevices : public ExprMutator { + public: + RewriteOnDevices() = default; + + private: + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Expr tuple = VisitExpr(tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy. 
+ Expr tuple_get_item = + TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + auto props = GetOnDeviceProps(tuple); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "wrapping tuple get item:" << std::endl + << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl + << "with \"on_device\" for device " << props.device_type; + return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + } else { + return tuple_get_item; + } + } + + Expr VisitExpr_(const LetNode* let_node) final { + auto expr = GetRef(let_node); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + Expr value = VisitExpr(inner_let_node->value); + auto props = GetOnDeviceProps(value); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising let-bound expression of let:" << std::endl + << PrettyPrint(expr) << std::endl + << "to be fixed to device " << props.device_type; + value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + expr = VisitExpr(expr); + // TODO(mbs): Avoid copy. 
+ for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, + /*span=*/std::get<2>(*itr)); + } + return expr; + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + Expr body = VisitExpr(function_node->body); + auto props = GetOnDeviceProps(body); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising body of function:" << std::endl + << PrettyPrint(GetRef(function_node)) << std::endl + << "to be fixed to device " << props.device_type; + body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + // TODO(mbs): Avoid copy + return Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + } +}; + +/****** +******* Phase 1 +*******/ + +/* + * \brief Collects the system of device constraints for all sub-expressions in a module. + * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. + * + * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, + * from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device \p d, and thus + * so must \p %y. + * + * Constraints can flow in interesting ways. E.g. in: + * \code + * let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) } + * let %g = fn(%f, %x, %y) { %f(%x, %y) } + * %g(%f, %a, %b) + * \endcode + * we discover \p %b must be on device \p d. + */ +class DeviceAnalyzer : public ExprVisitor { + public: + explicit DeviceAnalyzer(IRModule mod) + : mod_(std::move(mod)), domains_(std::make_unique()) {} + + /*! + * \brief Returns the expression-to-device-domain map for all expressions in all the global + * function definitions in the module. Expressions may have free domains, these will be resolved + * by \p DeviceDefaulter below. 
+ */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + domains_->UnifyExprExact(pair.first, pair.second); + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. 
Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. 
+ domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has device attributes then we can further constrain the + // function's domain to match them. + if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. 
Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. 
*/ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // no-op, constructors are handled at their call-sites. + // TODO(mbs): Assumes eta-expansion + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. 
+ } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** +******* Phase 2 +*******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * E.g. in: + * \code + * def @main(%x, %y, %z) { + * let %a = add(%x, %y); + * multiply(%a, on_device(%z, device_type=d)) + * \endcode + * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, + * and the device for the function result, are still 'free'. The global 'default' device type + * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * in turn 'fixes' the device on which the \p add and \p multiply are executed. 
+ * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. + */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. 
+ VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** +******* Phase 3 +*******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. 
+ * + * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * the device type for its parameters and result. + * + * - Additional "device_copy" CallNodes are inserted wherever there's a transition between + * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen + * where the original program explicitly allowed a transition using an "on_device" CallNode. + * That is, we do not not try to 'fix' a program with inconsistent devices. + * + * - Additional "on_device" CallNodes are inserted so that a later transform can discover + * the device for an arbitrary sub-expression by looking only for the lexically enclosing + * "on_device" CallNode or "on_device" function attribute. In particular, since function + * arguments and let-bound expressions can be on a device different from the function + * or let body itself we will insert "on_device" CallNodes to spell out any differences. This + * applies even to the argument to a "device_copy" CallNode, which may look pedantic but + * keeps downstream processing simple. The "on_device" calls should be removed before code gen, + * which is easily done on-the-fly. + * + * For example, we'll end up with programs that look like: + * \code + * def @main(%x, %y, param_device_types=[...], result_device_type=...) 
{ + * let %a = on_device(..., device_type=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), + * src_device_type=..., dst_device_type=...)) + * } + * \endcode + */ +class DeviceCapturer : public ExprMutator { + public: + DeviceCapturer(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + IRModule Capture() { + VLOG_CONTEXT << "CaptureDevices"; + IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); + for (const auto& pair : mod_->functions) { + VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + result->Add(pair.first, Downcast(Mutate(pair.second))); + } + return result; + } + + private: + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode + + Expr VisitExpr_(const TupleNode* tuple_node) final { + auto tuple = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& field : tuple_node->fields) { + fields.push_back(VisitChild(tuple, field)); + } + // TODO(mbs): Avoid copy + return Tuple(std::move(fields), tuple_node->span); + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + VLOG(1) << "capturing function:" << std::endl + << PrettyPrint(function) << std::endl + << "with domain:" << std::endl + << domains_->ToString(func_domain); + + // Gather the parameter and result device types for the function attributes. 
+ ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + Array param_device_types; + param_device_types.reserve(function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType); + param_device_types.push_back(param_device_type); + } + + // Rewrite the body. Note that the body may have begun with an "on_device" so + // be prepared to insert a "device_copy". + Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. + return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. 
+ return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + // TODO(mbs): Avoid copy + return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, + call_node->span); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. 
+ DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. + Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = 
VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. + DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! 
+ * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type + * (as required by the expression context the \p child is in) and the \p lexical_device_type + * (as a downstream transform would infer based only on lexically enclosing "on_device" + * CallNodes and function attributes.) Generally \p lexical_device_type and \p + * expected_device_type are the same by definition, but may differ in arguments to functions + * and let-bound expressions. + * + * If \p child_device_type differs from \p expected_device_type, wrap it as: + * \code + * device_copy(on_device(child', device_type=child_device_type), + * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * \endcode + * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the + * child. + * + * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * the expression as: + * \code + * on_device(..., device_type=expected_device_type) + * \endcode + * + * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped + * by a "device_copy", even though those copies will generally all be to the same destination + * device. + */ + Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, + DLDeviceType child_device_type, const Expr& child) { + ICHECK_NE(lexical_device_type, kInvalidDeviceType); + ICHECK_NE(expected_device_type, kInvalidDeviceType); + if (child->IsInstance()) { + // Primitive operators don't need to be rewritten and can have a different domain for + // each call site. 
+ return child; + } + Expr result = VisitExpr(child); + if (child_device_type != expected_device_type) { + VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type + << " to device type " << expected_device_type << " for:" << std::endl + << PrettyPrint(result); + // Also wrap the child in an "on_device" so downstream transforms can track devices + // lexically. + result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true); + result = DeviceCopy(result, child_device_type, expected_device_type); + } + if (expected_device_type != lexical_device_type) { + VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + << " for:" << std::endl + << PrettyPrint(result); + result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true); + } + return result; + } + + /*! + * Common case of visiting a direct \p child of \p parent where by default the \p child + * is expected to be on the same device as the \p parent. + */ + Expr VisitChild(const Expr& parent, const Expr& child) { + DLDeviceType expected_device_type = GetDeviceType(parent); + DLDeviceType child_device_type = GetDeviceType(child); + return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + } + + /*! \brief Module we are rewriting, so we can lookup global variables. */ + IRModule mod_; + /*! \brief Device domain for every expression from DeviceAnalyzer. */ + std::unique_ptr domains_; +}; + +/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ +tvm::transform::Pass Rewrite() { + auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { + return Downcast(RewriteOnDevices().Mutate(f)); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); +} + +/*! \brief Run the remaining phases. 
*/ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "param_device_types" and "result_device_type" + // attributes to all function definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** +******* Overall composite Pass +*******/ + +// This function is declared in the public . 
+TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 318022fb86f52..751271d2add3d 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -45,6 +45,15 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("dyn.squeeze"), + [this](const CallNode* call_node) { + auto args = PrepareArgs(call_node); + if (const ConstantNode* axis = args[1].as()) { + ICHECK_EQ(axis->data->ndim, 1); + return MakeSqueeze(call_node->args[0], ToVector(axis->data)); + } + return Expr(nullptr); + }}, {Op::Get("dyn.tile"), [this](const CallNode* call_node) { auto args = PrepareArgs(call_node); diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 657e2c3924555..31d3b2c8991aa 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -43,14 +43,15 @@ #include "../backend/te_compiler.h" #include "../backend/te_compiler_cache.h" +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" #include "../op/memory/memory.h" #include "../op/vm/vm.h" +#include "./let_list.h" #include "./pass_utils.h" -#include "let_list.h" -#include "pattern_utils.h" +#include "./pattern_utils.h" using namespace tvm::runtime; -using namespace tvm::relay::tec; namespace tvm { namespace relay { @@ -193,7 +194,8 @@ class DialectRewriter : public ExprMutator { private: // Insert a device 
copy node. Expr DeviceCopy(const Expr& inp, int src_dev, int dst_dev) { - return ExprMutator::Mutate(relay::DeviceCopy(inp, src_dev, dst_dev)); + return ExprMutator::Mutate(relay::DeviceCopy(inp, static_cast(src_dev), + static_cast(dst_dev))); } // Check if a call invokes a primitive function. @@ -274,9 +276,9 @@ class DialectRewriter : public ExprMutator { const std::vector& new_args) { Array shape_func_ins; - TECompiler compiler; + tec::TECompiler compiler; - CCacheKey key(func, target_host_); + tec::CCacheKey key(func, target_host_); auto cfunc = compiler->LowerShapeFunc(key); auto input_states = cfunc->shape_func_param_states; diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index fbb7bc9cd9858..7bfb31a299ad6 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -57,6 +57,21 @@ class TransformMemorizerNode : public Object { } }; + /*! + * \brief Defines the call transformation for derived passes. The new layouts are defined by + * used for different targets using a packed func. + * \param ref_call The original call. + * \param new_attrs Updated attributes consistent with new layouts. + * \param new_args The traversed/recursed args to the call. + * \return The new Call after calling the packed func. + */ + virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, + const std::vector& new_args) = 0; + + virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) { + return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); + } + /*! \brief The memorizer map. 
*/ std::unordered_map memo; @@ -69,11 +84,9 @@ class TransformMemorizerNode : public Object { */ class TransformMemorizer : public ObjectRef { public: - TransformMemorizer() {} + TransformMemorizer() = default; explicit TransformMemorizer(ObjectPtr n) : ObjectRef(n) {} - virtual ~TransformMemorizer() {} - TransformMemorizerNode* operator->() { return static_cast(get_mutable()); } @@ -146,19 +159,6 @@ class TransformMemorizer : public ObjectRef { return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); } - /*! - * \brief Defines the call transformation for derived passes. The new layouts are defined by - * used for different targets using a packed func. - * \param ref_call The original call. - * \param new_attrs Updated attributes consistent with new layouts. - * \param new_args The traversed/recursed args to the call. - * \return The new Call after calling the packed func. - */ - virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs, - const std::vector& new_args) = 0; - virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) { - return CallWithNewLayouts(ref_call, ref_call->attrs, new_args); - } using ContainerType = TransformMemorizerNode; }; @@ -312,7 +312,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj if (ref_call->op.as()) { Op op = Downcast(ref_call->op); if (falter_layout.count(op) && !finfer_layout.count(op)) { - return memorizer.CallWithNewLayouts(ref_call, normal_new_args); + return memorizer->CallWithNewLayouts(ref_call, normal_new_args); } } } @@ -349,7 +349,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj } // new_op = alter(op) - Call new_call = memorizer.CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args); + Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args); // new_in2, new_out = op.infer(new_in) if (new_call->op->IsInstance()) { diff --git 
a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 6c2371716b167..ebdf1fed2fab5 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -824,7 +824,6 @@ Pass InferType() { auto pass_info = PassInfo(0, "InferType", {}); return tvm::transform::CreateModulePass( [=](IRModule mod, const PassContext& pass_ctx) { - DLOG(INFO) << "tvm::relay::transform::InferType"; // Execute the pass function and return a new module. IRModule updated_mod = mod->ShallowCopy(); diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 5b9f5e17232ca..d1190df913756 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -44,6 +44,32 @@ typedef struct { void** data; } DnnlPackedArgs; +inline dnnl::memory::desc GenDNNLMemDescByShape(const dnnl::memory::dims& shape, + memory::data_type dtype) { + using tag = memory::format_tag; + + dnnl::memory::desc data_md; + + switch (shape.size()) { + case 2: + data_md = dnnl::memory::desc({shape, dtype, tag::ab}); + break; + case 3: + data_md = dnnl::memory::desc({shape, dtype, tag::abc}); + break; + case 4: + data_md = dnnl::memory::desc({shape, dtype, tag::abcd}); + break; + case 5: + data_md = dnnl::memory::desc({shape, dtype, tag::abcde}); + break; + default: + LOG(FATAL) << "Unsupported data shape dimension: " << shape.size(); + break; + } + return data_md; +} + // Read from memory, write to handle inline void read_from_dnnl_memory(void* handle, const memory& mem) { size_t bytes = mem.get_desc().get_size(); @@ -53,8 +79,9 @@ inline void read_from_dnnl_memory(void* handle, const memory& mem) { } void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, - int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, - int p_Kw_, int p_Sh_, int p_Sw_, primitive_attr attr) { + int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, + int p_Pw1_, int p_Kh_, int 
p_Kw_, int p_Sh_, int p_Sw_, + primitive_attr attr) { using tag = memory::format_tag; using dt = memory::data_type; engine eng(engine::kind::cpu, 0); @@ -64,10 +91,11 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; memory::dims conv2d_bias_tz = {p_O_}; - memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_, - (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; + memory::dims conv2d_dst_tz = {p_N_, p_O_, (p_H_ - p_Kh_ + p_Ph0_ + p_Ph1_ + p_Sh_) / p_Sh_, + (p_W_ - p_Kw_ + p_Pw0_ + p_Pw1_ + p_Sw_) / p_Sw_}; memory::dims conv2d_strides = {p_Sh_, p_Sw_}; - memory::dims conv2d_padding = {p_Ph_, p_Pw_}; + memory::dims conv2d_padding0 = {p_Ph0_, p_Pw0_}; + memory::dims conv2d_padding1 = {p_Ph1_, p_Pw1_}; auto user_src_memory = memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); auto user_weights_memory = @@ -81,7 +109,7 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in auto conv2d_desc = convolution_forward::desc( prop_kind::forward_inference, algorithm::convolution_direct, conv2d_src_md, conv2d_weights_md, - conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding, conv2d_padding); + conv2d_bias_md, conv2d_dst_md, conv2d_strides, conv2d_padding0, conv2d_padding1); auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, attr, eng); auto conv2d_src_memory = user_src_memory; @@ -98,12 +126,12 @@ void dnnl_conv2d_common(float* data, float* weights, float* bias, float* out, in } extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_, - int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, - int p_Kw_, int p_Sh_, int p_Sw_) { + int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, + int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) { primitive_attr attr; 
std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr); + p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, attr); } primitive_attr create_attr_with_relu_post_op() { @@ -117,20 +145,23 @@ primitive_attr create_attr_with_relu_post_op() { } extern "C" void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_, - int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, - int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_) { + int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, + int p_Pw0_, int p_Ph1_, int p_Pw1_, int p_Kh_, int p_Kw_, + int p_Sh_, int p_Sw_) { std::vector bias(p_O_, 0); return dnnl_conv2d_common(data, weights, bias.data(), out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, - p_Ph_, p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, + p_Ph0_, p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op()); } extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, - int p_G_, int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, - int p_Sh_, int p_Sw_) { - return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph_, - p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op()); + int p_G_, int p_Ph0_, int p_Pw0_, int p_Ph1_, + int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, + int p_Sw_) { + return dnnl_conv2d_common(data, weights, bias, out, p_N_, p_C_, p_H_, p_W_, p_O_, p_G_, p_Ph0_, + p_Pw0_, p_Ph1_, p_Pw1_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, + create_attr_with_relu_post_op()); } extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { @@ -170,16 +201,13 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int 
p_H_, int p_W_) { - using tag = memory::format_tag; +extern "C" void dnnl_relu(float* data, float* out, std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); @@ -236,27 +264,39 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { - using tag = memory::format_tag; +// should comply with src/relay/backend/contrib/dnnl/codegen.cc +#define DNNL_BINARY_ADD 0 +#define DNNL_BINARY_MUL 1 + +extern "C" void dnnl_binary_op(float* data, float* weight, float* out, int algo_type, + std::vector shape) { using dt = memory::data_type; engine eng(engine::kind::cpu, 0); stream s(eng); - memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; - - auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; - auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); - auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + auto data_md = GenDNNLMemDescByShape(shape, dt::f32); auto data_memory = memory(data_md, eng, data); - auto weight_memory = memory(weight_md, eng, weight); - auto dst_memory = memory(dst_md, eng); + auto weight_memory = memory(data_md, eng, weight); + auto dst_memory = memory(data_md, eng); - auto add_desc = binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + algorithm algo = algorithm::undef; + switch (algo_type) { + case DNNL_BINARY_ADD: + algo = algorithm::binary_add; + break; + case DNNL_BINARY_MUL: + algo = algorithm::binary_mul; + break; + default: + LOG(FATAL) << "Unsupported dnnl algorithm: " << algo_type; + break; + } + + auto add_desc = binary::desc(algo, data_md, data_md, data_md); 
auto add_prim_desc = binary::primitive_desc(add_desc, eng); - assert(dst_md == add_prim_desc.dst_desc()); + assert(data_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); add.execute( diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index eef67a702d9c6..b32d137a25667 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -113,7 +113,9 @@ class DNNLJSONRuntime : public JSONRuntimeBase { } else if ("nn.relu" == op_name) { Relu(nid); } else if ("add" == op_name) { - Add(nid); + Binary(nid, dnnl::algorithm::binary_add); + } else if ("multiply" == op_name) { + Binary(nid, dnnl::algorithm::binary_mul); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -163,16 +165,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dnnl::memory::dim N = input_shape[0], // batch size IC = input_shape[1], // input channels IH = input_shape[2], // input height - IW = input_shape[2], // input width + IW = input_shape[3], // input width OC = weight_shape[0], // output channels KH = weight_shape[2], // weight height KW = weight_shape[3], // weight width - PH_L = std::stoi(str_padding[1]), // height padding: left - PH_R = std::stoi(str_padding[3]), // height padding: right - PW_L = std::stoi(str_padding[0]), // width padding: left - PW_R = std::stoi(str_padding[2]), // width padding: right + PW_L = std::stoi(str_padding[1]), // width padding: left + PW_R = std::stoi(str_padding[3]), // width padding: right + PH_L = std::stoi(str_padding[0]), // height padding: top + PH_R = std::stoi(str_padding[2]), // height padding: bottom SH = std::stoi(str_strides[0]), // height-wise stride - SW = std::stoi(str_strides[0]), // weight-wise stride + SW = std::stoi(str_strides[1]), // weight-wise stride OH = (IH - KH + PH_L + PH_R) / SH + 1, // output height OW = (IW - KW + PW_L + PW_R) / SW + 1; // output width @@ -338,7 +340,7 @@ class DNNLJSONRuntime : public 
JSONRuntimeBase { auto data_entry = node.GetInputs()[0]; dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - auto data_md = dnnl::memory::desc{{shape}, dt::f32, tag::abcd}; + dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, data_md, 0); @@ -349,14 +351,13 @@ class DNNLJSONRuntime : public JSONRuntimeBase { net_.push_back(relu); auto data_memory = BindDNNLMemory(data_entry, data_md); - auto out_md = dnnl::memory::desc(shape, dt::f32, tag::abcd); JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); + auto out_memory = BindDNNLMemory(out_entry, data_md); net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); } - void Add(const size_t& nid) { + void Binary(const size_t& nid, dnnl::algorithm algo) { auto node = nodes_[nid]; // Memory and compute description. @@ -378,11 +379,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase { JSONGraphNodeEntry out_entry(nid, 0); auto out_memory = BindDNNLMemory(out_entry, out_md); - auto add_desc = - dnnl::binary::desc(dnnl::algorithm::binary_add, data_mds[0], data_mds[1], out_md); - auto add_prim_desc = dnnl::binary::primitive_desc(add_desc, engine_); - auto add = dnnl::binary(add_prim_desc); - net_.push_back(add); + auto binary_desc = dnnl::binary::desc(algo, data_mds[0], data_mds[1], out_md); + auto binary_prim_desc = dnnl::binary::primitive_desc(binary_desc, engine_); + auto binary = dnnl::binary(binary_prim_desc); + net_.push_back(binary); net_args_.push_back({{DNNL_ARG_SRC_0, data_memories[0]}, {DNNL_ARG_SRC_1, data_memories[1]}, diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index dbc064a6bc993..522313ae5a640 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,9 @@ #define 
TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include +#include + +#include #include "dnnl.hpp" @@ -36,31 +39,32 @@ namespace contrib { using namespace dnnl; extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, - int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, - int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); + int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph0_, int p_Pw0_, + int p_Ph1_, int p_Pw1_, int p_Kh_, int p_Kw_, int p_Sh_, + int p_Sw_); extern "C" TVM_DLL void dnnl_fused_conv2d_relu(float* data, float* weights, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, - int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, - int p_Sh_, int p_Sw_); + int p_Ph0_, int p_Pw0_, int p_Ph1_, int p_Pw1_, + int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); extern "C" TVM_DLL void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* bias, float* out, int p_N_, int p_C_, int p_H_, - int p_W_, int p_O_, int p_G_, int p_Ph_, - int p_Pw_, int p_Kh_, int p_Kw_, int p_Sh_, - int p_Sw_); + int p_W_, int p_O_, int p_G_, int p_Ph0_, + int p_Pw0_, int p_Ph1_, int p_Pw1_, int p_Kh_, + int p_Kw_, int p_Sh_, int p_Sw_); extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_); -extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, std::vector shape); extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, float* out, float* new_mean, float* new_variance, int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); -extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, - int p_h_, int p_w_); +extern "C" TVM_DLL void dnnl_binary_op(float* data, float* weight, float* out, int binary_algo, + std::vector shape); } // namespace contrib } // namespace runtime diff --git 
a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc index 23f7339605dfa..c60928e95db41 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -61,9 +61,6 @@ TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, this->calibrator_ = calibrator; if (calibrator != nullptr) { use_int8_ = true; - builder_->setFp16Mode(true); - builder_->setInt8Mode(true); - builder_->setInt8Calibrator(calibrator); } network_ = builder_->createNetworkV2(flags); #else diff --git a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h index 1e340d287629a..58bfcc248f6e8 100755 --- a/src/runtime/contrib/tensorrt/tensorrt_calibrator.h +++ b/src/runtime/contrib/tensorrt/tensorrt_calibrator.h @@ -62,13 +62,13 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { data_sizes_.push_back(binding_sizes); } - int getBatchSize() const override { return batch_size_; } + int getBatchSize() const noexcept override { return batch_size_; } /*! * \brief TensorRT will call this method to get next batch of data to * calibrate with. 
*/ - bool getBatch(void* bindings[], const char* names[], int nbBindings) override { + bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override { AllocateBuffersIfNotAllocated(); CHECK_EQ(input_names_.size(), nbBindings); for (size_t i = 0; i < input_names_.size(); ++i) { @@ -83,13 +83,13 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { return (num_batches_calibrated_ < data_.size()); } - const void* readCalibrationCache(size_t& length) override { + const void* readCalibrationCache(size_t& length) noexcept override { if (calibration_cache_.empty()) return nullptr; length = calibration_cache_.size(); return calibration_cache_.data(); } - void writeCalibrationCache(const void* cache, size_t length) override { + void writeCalibrationCache(const void* cache, size_t length) noexcept override { calibration_cache_.assign(static_cast(cache), length); } diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc index 8f0537ad7adca..9c193921f93ba 100644 --- a/src/runtime/logging.cc +++ b/src/runtime/logging.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include namespace tvm { @@ -119,7 +120,7 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const char* filename, int li std::string Backtrace() { BacktraceInfo bt; - bt.max_size = 100; + bt.max_size = 500; if (_bt_state == nullptr) { return ""; } @@ -166,10 +167,102 @@ namespace tvm { namespace runtime { namespace detail { +namespace { +constexpr const char* kSrcPrefix = "/src/"; +// Note: Better would be std::char_traits::length(kSrcPrefix) but it is not +// a constexpr on all compilation targets. +constexpr const size_t kSrcPrefixLength = 5; +constexpr const char* kDefaultKeyword = "DEFAULT"; +} // namespace + +/* static */ +TvmLogDebugSettings TvmLogDebugSettings::ParseSpec(const char* opt_spec) { + TvmLogDebugSettings settings; + if (opt_spec == nullptr) { + // DLOG and VLOG disabled. 
+ return settings; + } + std::string spec(opt_spec); + if (spec.empty() || spec == "0") { + // DLOG and VLOG disabled. + return settings; + } + settings.dlog_enabled_ = true; + if (spec == "1") { + // Legacy specification for enabling just DLOG. + return settings; + } + std::istringstream spec_stream(spec); + while (spec_stream) { + std::string name; + if (!std::getline(spec_stream, name, '=')) { + // Reached end. + break; + } + if (name.empty()) { + LOG(FATAL) << "TVM_LOG_DEBUG ill-formed, empty name"; + return settings; + } + + std::string level; + if (!std::getline(spec_stream, level, ';')) { + LOG(FATAL) << "TVM_LOG_DEBUG ill-formed, expecting level"; + return settings; + } + if (level.empty()) { + LOG(FATAL) << "TVM_LOG_DEBUG ill-formed, empty level"; + return settings; + } + // Parse level, default to 0 if ill-formed which we don't detect. + char* end_of_level = nullptr; + int level_val = static_cast(strtol(level.c_str(), &end_of_level, 10)); + if (end_of_level != level.c_str() + level.size()) { + LOG(FATAL) << "TVM_LOG_DEBUG ill-formed, invalid level"; + return settings; + } + LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name << "' up to level " << level; + settings.vlog_level_map_.emplace(name, level_val); + } + return settings; +} + +bool TvmLogDebugSettings::VerboseEnabledImpl(const std::string& filename, int level) const { + // Canonicalize the filename. + // TODO(mbs): Not Windows friendly. + size_t last_src = filename.rfind(kSrcPrefix, std::string::npos, kSrcPrefixLength); + // Strip anything before the /src/ prefix, on the assumption that will yield the + // TVM project relative filename. If no such prefix fallback to filename without + // canonicalization. + std::string key = + last_src == std::string::npos ? filename : filename.substr(last_src + kSrcPrefixLength); + // Check for exact match. + auto itr = vlog_level_map_.find(key); + if (itr != vlog_level_map_.end()) { + return level <= itr->second; + } + // Check for default. 
+ itr = vlog_level_map_.find(kDefaultKeyword); + if (itr != vlog_level_map_.end()) { + return level <= itr->second; + } + return false; +} + LogFatal::Entry& LogFatal::GetEntry() { static thread_local LogFatal::Entry result; return result; } + +std::string VLogContext::str() const { + std::stringstream result; + for (const auto* entry : context_stack_) { + ICHECK_NOTNULL(entry); + result << entry->str(); + result << ": "; + } + return result.str(); +} + } // namespace detail } // namespace runtime } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 968a4488bbcfe..8db89c59a85d7 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -272,7 +272,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ dtype.code = static_cast(dtype_code); dtype.bits = static_cast(dtype_bits); dtype.lanes = static_cast(dtype_lanes); - Device dev; + tvm::Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev); @@ -286,7 +286,7 @@ TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, int ndim = args[1]; ShapeTuple shape(shape_ptr, shape_ptr + ndim); DataType dtype = args[2]; - Device dev = args[3]; + tvm::Device dev = args[3]; Optional mem_scope = args[4]; auto ndarray = NDArray::Empty(shape, dtype, dev, mem_scope); *ret = ndarray; diff --git a/src/support/array.h b/src/support/array.h index 89e17433344b2..95b4f58a2e22f 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -75,9 +75,33 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector * \return The result vector */ template -std::vector AsVector(const Array& vec); +inline std::vector AsVector(const Array& vec); + +/*! 
+ * \brief Convert a std::vector to tvm::runtime::Array + * \tparam TSrc The type of elements in the source vector + * \tparam TDst The type of elements in the result Array + * \return The result vector + */ +template +inline Array AsArray(const std::vector& vec); + +/*! + * \brief Get the shape tuple as array + * \param shape The shape tuple + * \return An array of the shape tuple + */ +inline Array AsArray(const ShapeTuple& shape) { + Array result; + result.reserve(shape->size); + for (ShapeTuple::index_type i : shape) { + result.push_back(Integer(i)); + } + return result; +} /********** Implementation details of AsVector **********/ + namespace details { template @@ -130,11 +154,68 @@ struct AsVectorImpl { }; } // namespace details +/********** Implementation details of AsArray **********/ + +namespace details { + +template +struct AsArrayImpl {}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + return Array(vec.begin(), vec.end()); + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (int x : vec) { + result.push_back(Integer(x)); + } + return result; + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (int64_t x : vec) { + result.push_back(Integer(x)); + } + return result; + } +}; + +template +struct AsArrayImpl { + inline Array operator()(const std::vector& vec) const { + Array result; + result.reserve(vec.size()); + for (double x : vec) { + result.push_back(FloatImm(tvm::DataType::Float(64), x)); + } + return result; + } +}; + +} // namespace details + template inline std::vector AsVector(const Array& vec) { return details::AsVectorImpl()(vec); } +template +inline Array AsArray(const std::vector& vec) { + return details::AsArrayImpl()(vec); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_ 
diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc index 4ced0df6ddf3d..e90967562d163 100644 --- a/src/support/parallel_for.cc +++ b/src/support/parallel_for.cc @@ -67,8 +67,8 @@ void parallel_for(int begin, int end, const std::function& f, int ste res_vec.reserve(run_partitions.size()); for (const auto& run_partition : run_partitions) { std::packaged_task&, const std::function&)> task( - [](const std::vector& run_pattition, const std::function& f) { - for (const auto& i : run_pattition) { + [](const std::vector& run_partition, const std::function& f) { + for (const auto& i : run_partition) { f(i); } }); @@ -93,5 +93,52 @@ void parallel_for(int begin, int end, const std::function& f, int ste } } +void parallel_for_dynamic(int begin, int end, int num_threads, + const std::function& f) { + // Step 1. Sanity checks + if (begin == end) { + return; + } + CHECK_LE(begin, end) << "ValueError: The interval [begin, end) requires `begin <= end`"; + CHECK_GT(num_threads, 0) << "ValueError: `num_threads` should be positive"; + // Step 2. Launch threads + // Step 2.1. Launch worker 1 to worker `num_threads - 1` + std::atomic counter{begin}; + std::vector> futures; + std::vector threads; + futures.reserve(num_threads - 1); + threads.reserve(num_threads - 1); + auto worker = [end, &counter, &f](int thread_id) -> void { + for (int task_id; (task_id = counter++) < end;) { + f(thread_id, task_id); + } + }; + for (int thread_id = 1; thread_id < num_threads; ++thread_id) { + std::packaged_task task(worker); + futures.emplace_back(task.get_future()); + threads.emplace_back(std::move(task), thread_id); + } + // Step 2.2. Launch worker 0 inplace + try { + worker(0); + } catch (const std::exception& e) { + for (auto&& thread : threads) { + thread.join(); + } + LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what(); + } + // Step 3. 
Join threads and check exceptions + for (auto&& thread : threads) { + thread.join(); + } + try { + for (auto&& future : futures) { + future.get(); + } + } catch (const std::exception& e) { + LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what(); + } +} + } // namespace support } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ab96d6e69d143..466f85393b1b9 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -246,8 +246,9 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { + +CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { if (kind < builtin::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -257,57 +258,87 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } switch (kind) { case builtin::kArrAddr: { - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index)); } case builtin::kArrData: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrShape: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(4); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrStrides: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); + 
llvm::Type* member_type = t_tvm_array_->getStructElementType(5); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrNDim: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeCode: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeBits: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeLanes: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrByteOffset: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(6); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)}); + return TypedPointer(member_type, member_addr); } case 
builtin::kArrDeviceId: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrDeviceType: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } default: LOG(FATAL) << "unknown field code"; - return nullptr; + return TypedPointer(); } } @@ -373,7 +404,10 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - 
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -440,9 +474,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // $xxx_compute_ functions are not global. They should be marked as static (via InternalLinkage) // to call them correctly on MIPS platform (CALL16 issue) // Linkage ld Error: CALL16 reloc at 0x290 not against global symbol - llvm::Function* fcompute = llvm::Function::Create( - ftype, llvm::Function::InternalLinkage, - op->value.as()->value.operator llvm::StringRef(), module_.get()); + const StringImmNode* value = op->value.as(); + ICHECK(value != nullptr); + llvm::Function* fcompute = + llvm::Function::Create(ftype, llvm::Function::InternalLinkage, + value->value.operator llvm::StringRef(), module_.get()); BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); // setup compute function. std::unordered_map new_vmap; @@ -473,22 +509,26 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } #endif } + auto new_analyzer = std::make_unique(); std::swap(function_, fcompute); - std::swap(new_vmap, var_map_); + std::swap(analyzer_, new_analyzer); + std::swap(var_map_, new_vmap); BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); // swap the var map back, now we are back on track. 
- std::swap(new_vmap, var_map_); + std::swap(var_map_, new_vmap); + std::swap(analyzer_, new_analyzer); std::swap(function_, fcompute); builder_->SetInsertPoint(compute_call_end); } -llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* num_bytes) { +CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, + uint64_t* num_bytes) { if (vfields.size() == 0) { *num_bytes = 0U; - return llvm::Constant::getNullValue(t_void_p_); + return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_)); } std::vector fields; for (Var v : vfields) { @@ -496,23 +536,24 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu ICHECK(it != var_map_.end()); fields.push_back(it->second->getType()); } - llvm::StructType* tcdata = llvm::StructType::create(fields); - llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); + llvm::StructType* ctype = llvm::StructType::create(fields); + llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { builder_->CreateStore(var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)})); } - *num_bytes = data_layout_->getTypeAllocSize( - llvm::cast(cdata->getType())->getElementType()); - return cdata; + *num_bytes = data_layout_->getTypeAllocSize(ctype); + return TypedPointer(ctype, cvalue); } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, +void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { - (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); + llvm::Type* field_type = cdata.type->getStructElementType(i); + llvm::Value* field_addr = + builder_->CreateInBoundsGEP(cdata.type, cdata.addr, 
{ConstInt32(0), ConstInt32(i)}); + (*vmap)[vfields[i].get()] = builder_->CreateLoad(field_type, field_addr); } } @@ -525,21 +566,22 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { // allocate and setup the closure, call the closure. Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( - launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + launch_callee, + {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); llvm::Value* penv = &(*it++); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. 
std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); @@ -548,16 +590,20 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = builder_->CreateLoad( + t_int32_, + builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; + auto new_analyzer = std::make_unique(); std::swap(function_, f); std::swap(parallel_env_, par_env); + std::swap(analyzer_, new_analyzer); std::swap(var_map_, new_vmap); this->VisitStmt(body); builder_->CreateRet(ConstInt32(0)); // swap the var map back, now we are back on track. std::swap(var_map_, new_vmap); + std::swap(analyzer_, new_analyzer); std::swap(parallel_env_, par_env); std::swap(function_, f); ICHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch"; @@ -592,24 +638,27 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod // allocate and setup the closure, call the closure. uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( - finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. 
BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); ICHECK(parallel_env_.penv == nullptr); + auto new_analyzer = std::make_unique(); std::swap(function_, f); + std::swap(analyzer_, new_analyzer); std::swap(var_map_, new_vmap); this->VisitStmt(body); builder_->CreateRet(ConstInt32(0)); // swap the var map back, now we are back on track. std::swap(var_map_, new_vmap); + std::swap(analyzer_, new_analyzer); std::swap(function_, f); builder_->SetInsertPoint(init_end); } @@ -644,7 +693,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -656,8 +707,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + 
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -671,7 +725,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -686,10 +743,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, - const int64_t begin, const int64_t end) { - using llvm::BasicBlock; +CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& args, + const DataType& r_type, + const int64_t begin, const int64_t end) { + PackedCall pc; std::string func_name = args[0].as()->value; llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function @@ -698,70 +755,80 @@ llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm:: llvm::Value* stack_value = MakeValue(args[1]); llvm::Value* stack_tcode = MakeValue(args[2]); llvm::Value* arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), 
+ ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); + llvm::Value* call = builder_->CreateCall( + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::BasicBlock* end_block = CheckCallSuccess(call); + + // Load the return value and cast it to the designated type (r_type). 
DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - *rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else - *rvalue = builder_->CreateAlignedLoad(load_ptr, 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif - *rvalue = CreateCast(r_api_type, r_type, *rvalue); - return end_block; + pc.ret_value = CreateCast(r_api_type, r_type, rvalue); + + // Load the return type code. +#if TVM_LLVM_VERSION >= 110 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); +#else + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); +#endif + + pc.end_block = end_block; + return pc; } llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 5U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); - return rvalue; + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); + return pc.ret_value; } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { - using llvm::BasicBlock; ICHECK_EQ(op->args.size(), 6U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - BasicBlock* end_block = - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, 
op->args[3].as()->value, - op->args[4].as()->value); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); -#if TVM_LLVM_VERSION >= 110 - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); -#else - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); -#endif + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + // Check the ret_type_code and create cmp instruction. llvm::Value* cmp = - builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. 
llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); - phi_rvalue->addIncoming(rvalue, update_block); - phi_rvalue->addIncoming(traced_value, end_block); + phi_rvalue->addIncoming(pc.ret_value, update_block); + phi_rvalue->addIncoming(traced_value, pc.end_block); return phi_rvalue; } @@ -868,24 +935,24 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = - this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); + TypedPointer ref = + CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } else { - return builder_->CreateLoad(ref); + return builder_->CreateLoad(ref.type, ref.addr); } } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); ICHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2U); @@ -941,7 +1008,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::coproc_uop_scope) { - 
this->CreateStaticInit(op->value.as()->value, op->body); + const StringImmNode* value = op->value.as(); + ICHECK(value != nullptr); + this->CreateStaticInit(value->value, op->body); } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (tir::attr::IsPragmaKey(op->attr_key)) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index d08bd639e1311..402189eb374d1 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -105,13 +105,17 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value* cdata, const Array& fields, + TypedPointer PackClosureData(const Array& fields, uint64_t* num_bytes); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); + void UnpackClosureData(TypedPointer cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, + struct PackedCall { + llvm::Value* ret_value; + llvm::Value* ret_tcode; + llvm::BasicBlock* end_block; + }; + PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. 
llvm::Value* CreateCallPacked(const CallNode* op); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index d9d0d1f3d6a45..bffb620d49f90 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -75,7 +75,7 @@ class CodeGenHexagon final : public CodeGenLLVM { llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr}; private: - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); // Check if the call to packed function is successful // if not directly finalize function and pass on return code. @@ -97,8 +97,12 @@ class CodeGenHexagon final : public CodeGenLLVM { std::unordered_map func_handle_map_; // Make packed call. - llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, + struct PackedCall { + llvm::Value* ret_value; + llvm::Value* ret_tcode; + llvm::BasicBlock* end_block; + }; + PackedCall MakeCallPackedLowered(const Array& args, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. 
llvm::Value* CreateCallPacked(const CallNode* op); @@ -251,7 +255,10 @@ llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::st llvm::Value* CodeGenHexagon::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -296,11 +303,11 @@ llvm::Value* CodeGenHexagon::RuntimeTVMAPISetLastError() { return GetContextPtr(gv_tvm_api_set_last_error_); } -llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array& args, llvm::Value** rvalue, - llvm::Value** ret_tcode, const DataType& r_type, - const int64_t begin, const int64_t end) { - using llvm::BasicBlock; - // using namespace tir; +CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array& args, + const DataType& r_type, + const int64_t begin, + const int64_t end) { + PackedCall pc; std::string func_name = args[0].as()->value; llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function @@ -309,29 +316,48 @@ llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array& args, ll llvm::Value* stack_value = MakeValue(args[1]); llvm::Value* stack_tcode = MakeValue(args[2]); llvm::Value* arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); 
llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); + llvm::Value* call = builder_->CreateCall( + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); + llvm::BasicBlock* end_block = CheckCallSuccess(call); + + // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - *rvalue = builder_->CreateAlignedLoad( - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), - llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else - *rvalue = builder_->CreateAlignedLoad( - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), 8); + llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif - *rvalue = CreateCast(r_api_type, r_type, *rvalue); - return end_block; + pc.ret_value = CreateCast(r_api_type, r_type, rvalue); + + // Load the return type code. 
+#if TVM_LLVM_VERSION >= 110 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); +#else + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); +#endif + + pc.end_block = end_block; + return pc; } llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { @@ -364,7 +390,9 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, align); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -376,8 +404,11 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -391,7 +422,10 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, 
GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -417,44 +451,34 @@ llvm::Value* CodeGenHexagon::CreateCallPacked(const CallNode* op) { } ICHECK_EQ(op->args.size(), 5U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); - return rvalue; + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); + return pc.ret_value; } llvm::Value* CodeGenHexagon::CreateCallTracePacked(const CallNode* op) { - using llvm::BasicBlock; ICHECK_EQ(op->args.size(), 6U); - llvm::Value* rvalue = nullptr; - llvm::Value* ret_tcode = nullptr; - BasicBlock* end_block = - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. 
- BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); -#if TVM_LLVM_VERSION >= 110 - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); -#else - llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); -#endif + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + // Check the ret_type_code and create cmp instruction. llvm::Value* cmp = - builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + builder_->CreateICmpNE(pc.ret_tcode, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); - phi_rvalue->addIncoming(rvalue, update_block); - phi_rvalue->addIncoming(traced_value, end_block); + phi_rvalue->addIncoming(pc.ret_value, update_block); + phi_rvalue->addIncoming(traced_value, pc.end_block); return phi_rvalue; } @@ -508,23 +532,23 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3); int kind = op->args[2].as()->value; - llvm::Value* ref = + TypedPointer ref = CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } - return builder_->CreateLoad(ref); + return builder_->CreateLoad(ref.type, ref.addr); } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4); int kind = op->args[2].as()->value; ICHECK(kind != builtin::kArrAddr); - llvm::Value* ref = CreateStructRefPtr(op->args[3].dtype(), 
MakeValue(op->args[0]), + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); llvm::Value* value = MakeValue(op->args[3]); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2); @@ -543,8 +567,8 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { return CodeGenLLVM::CreateIntrinsic(op); } -llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { +CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { static const std::map field_index = { {builtin::kArrData, 0}, {builtin::kArrDeviceType, 1}, {builtin::kArrDeviceId, 1}, {builtin::kArrNDim, 2}, {builtin::kArrTypeCode, 3}, {builtin::kArrTypeBits, 3}, @@ -575,12 +599,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll uint64_t byte_offset; kArrByteOffset } DLTensor; */ - llvm::Value* base_gep = builder_->CreateInBoundsGEP(buf, index, "base_gep"); + llvm::Value* base_gep = builder_->CreateInBoundsGEP(t_tvm_array_, buf, index, "base_gep"); if (kind == builtin::kArrAddr) { - return base_gep; + return TypedPointer(t_void_p_, base_gep); } llvm::Value* field_gep = builder_->CreateInBoundsGEP( - base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); + t_tvm_array_, base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); + llvm::Type* field_type = t_tvm_array_->getStructElementType(field_index.at(kind)); switch (kind) { // These fields have no sub-fields. 
case builtin::kArrData: @@ -588,10 +613,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll case builtin::kArrShape: case builtin::kArrStrides: case builtin::kArrByteOffset: - return field_gep; + return TypedPointer(field_type, field_gep); } - return builder_->CreateInBoundsGEP( - field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, "subfield_gep"); + llvm::Value* subfield_gep = builder_->CreateInBoundsGEP( + field_type, field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, + "subfield_gep"); + llvm::Type* subfield_type = field_type->getStructElementType(subfield_index.at(kind)); + return TypedPointer(subfield_type, subfield_gep); } if (kind == builtin::kTVMValueContent) { @@ -609,20 +637,20 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_void_p_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } assert(!"Unknown kind"); - return nullptr; + return TypedPointer(); } namespace { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6aabdc1bd804f..12fbf2c3e42c3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ 
-473,9 +473,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); // Extract the underlying type of the allocated buffer. - llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType(); - if (buf_type->isPointerTy()) { - buf_type = buf_type->getPointerElementType(); + DataType dtype = buffer->dtype; + if (buffer->type_annotation.defined()) { + Type element_type = Downcast(buffer->type_annotation)->element_type; + if (auto* ptype = element_type.as()) { + dtype = ptype->dtype; + } + } + llvm::Type* buf_type = DTypeToLLVMType(dtype); + if (!buf_type) { + buf_type = t_void_p_; } std::string tmp; @@ -737,14 +744,17 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { +CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, + llvm::Value* index) { llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); ICHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); - if (btype != ptype) { - buffer = builder_->CreatePointerCast(buffer, ptype); + llvm::Type* llvm_type = DTypeToLLVMType(t); + llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); + if (btype != ttype) { + buffer = builder_->CreatePointerCast(buffer, ttype); } - return builder_->CreateInBoundsGEP(buffer, index); + llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index); + return TypedPointer(llvm_type, ptr); } llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { @@ -861,10 +871,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " -#if 
TVM_LLVM_VERSION <= 130 - << llvm::Intrinsic::getName(id, {}); +#if TVM_LLVM_VERSION >= 130 + << llvm::Intrinsic::getBaseName(id).str(); #else - << llvm::Intrinsic::getName(id, return_type, {}); + << llvm::Intrinsic::getName(id, {}); #endif return builder_->CreateCall(f, arg_value); } else if (op->op.same_as(builtin::bitwise_and())) { @@ -888,18 +898,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); ICHECK(op->args.size() == 1 && l); - const RampNode* r = l->index.as(); - llvm::Value* ptr; - unsigned addrspace; - if (!r) { - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); - } else { + TypedPointer buffer_ptr; + if (const RampNode* r = l->index.as()) { PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + } else { + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); } - return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); + unsigned addrspace = + llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); + return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->op.same_as(builtin::isnullptr())) { @@ -1154,29 +1162,40 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { if (t.lanes() == 1) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = 
CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. 
+ TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; @@ -1187,11 +1206,15 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, 
op->buffer_var.get(), PrimExpr()); @@ -1271,30 +1294,36 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { if (t.lanes() == 1) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; } else { // vector store - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. 
+ TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; @@ -1305,13 +1334,14 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { // scalarized store. int basic_align = t.bits() / 8; auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore( + builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); }; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index a4f007aeebed0..177b530563545 100644 --- 
a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -181,6 +181,15 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: + /*! + * \brief Address and type pair to assist in handling opaque pointers. + */ + struct TypedPointer { + TypedPointer() = default; + TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {} + llvm::Type* type = nullptr; /*!< Type of the value pointed to. */ + llvm::Value* addr = nullptr; /*!< Address of the value. */ + }; /*! \brief The storage information */ struct StorageInfo { /*! \brief The alignment of allocation */ @@ -301,7 +310,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); // Vector concatenation. 
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 7abff36a3ddb9..d93a7fde639a4 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -59,8 +59,6 @@ class InferTextureAccess : public StmtExprVisitor { var_access_map_[op->args[0].as()] |= kReadAccess; } else if (op->op.same_as(builtin::texture2d_store())) { var_access_map_[op->args[0].as()] |= kWriteAccess; - } else { - StmtExprVisitor::VisitExpr_(op); } StmtExprVisitor::VisitExpr_(op); } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 0382b8071de7a..b6c41b958c310 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -172,8 +172,9 @@ std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; - DLOG(INFO) << "verifying memory for target '" << target.value()->str() << "' for primitive\n" - << PrettyPrint(func); + VLOG(1) << "verifying memory for target '" << target.value()->str() + << "' for primitive:" << std::endl + << PrettyPrint(func); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 5db131c44f2ab..d08bef2ab91ad 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -273,6 +273,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value, op->span); } + ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types."; return tir::Cast(t, value, span); } else { if (value.dtype().lanes() == 1) { diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 07af73ebabb6e..93eba520f9d13 100644 --- 
a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -220,9 +220,7 @@ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState se } support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { - // In order for reproducibility, we computer the new seed using RNG's random state and a different - // set of parameters. Note that both 32767 and 1999999973 are prime numbers. - return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; + return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 05eefaca8a11f..8d8acd2693f45 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,6 +26,14 @@ namespace tvm { namespace tir { /******** Schedule: Sampling ********/ +/*! + * \brief Sample a random integer from a given range. + * \param min_inclusive The minimum value of the range, inclusive. + * \param max_exclusive The maximum value of the range, exclusive. + * \return The random integer sampled in the given range. + */ +TVM_DLL int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive); /*! * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update @@ -72,6 +80,7 @@ TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, * 1) The loops can't have annotations or thread bindings. * 2) The inner loop must be the only child of the outer loop. * 3) All loops must start with 0. + * 4) The domain of a loop to be fused cannot depend on another loop to be fused. 
* \param self The state of the schedule * \param loop_srefs An array of srefs to the loops to be fused * \return The sref to the fused loop diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 95c92aa0a3229..7b9ac488b8b93 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -358,17 +358,26 @@ class LoopsNotAChainError : public ScheduleError { class DependentLoopError : public ScheduleError { public: - explicit DependentLoopError(IRModule mod, For loop, String inner_var) - : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {} + enum class PrimitiveKind { kFuse, kReorder }; + explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind) + : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {} String FastErrorString() const final { - return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " - "in the new order"; + if (kind_ == PrimitiveKind::kReorder) { + return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " + "in the new order"; + } else { + return "ScheduleError: A loop's `extent` is dependent on another loop"; + } } String DetailRenderTemplate() const final { - return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + - " in the new order"; + if (kind_ == PrimitiveKind::kReorder) { + return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + + " in the new order"; + } else { + return "A loop {0}'s `extent` is dependent on another loop " + inner_var_; + } } IRModule mod() const final { return mod_; } @@ -377,6 +386,7 @@ class DependentLoopError : public ScheduleError { IRModule mod_; For loop_; String inner_var_; + PrimitiveKind kind_; }; Array Split(ScheduleState self, const StmtSRef& loop_sref, @@ -450,6 +460,7 @@ StmtSRef 
Fuse(ScheduleState self, const Array& loop_srefs) { StmtSRef outer_loop_sref{nullptr}; const ForNode* outer_loop = nullptr; arith::Analyzer analyzer; + std::unordered_set outer_loop_vars; // Step 1. check correctness for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); @@ -469,6 +480,19 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { if (!analyzer.CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); } + const VarNode* used_var = nullptr; + auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) { + if (outer_loop_vars.count(var)) { + used_var = var; + return true; + } + return false; + }; + if (UsesVar(loop->extent, f_contain)) { + throw DependentLoopError(self->mod, GetRef(loop), used_var->name_hint, + DependentLoopError::PrimitiveKind::kFuse); + } + outer_loop_vars.insert(loop->loop_var.get()); loops.push_back(loop); } // Step 2. Create fused loop var and replace the original loop vars @@ -651,7 +675,8 @@ For ConstructNewLoopChain(const ScheduleState& self, std::vectormin, f_contain) || UsesVar(copy->extent, f_contain)) { - throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint); + throw DependentLoopError(self->mod, GetRef(copy), used_var->name_hint, + DependentLoopError::PrimitiveKind::kReorder); } inner_vars.insert(copy->loop_var.get()); new_loop = For(std::move(n)); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 8843ac6131794..6ac6226118cde 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -24,6 +24,18 @@ namespace tvm { namespace tir { +int SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int min_inclusive, + int max_exclusive) { + CHECK(min_inclusive < max_exclusive) + << "ValueError: max_exclusive must be greater than min_inclusive."; + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + 
support::LinearCongruentialEngine rand_(rand_state); + std::uniform_int_distribution dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 97f5b6f90a704..c4b83e05706dd 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -98,7 +98,13 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const ForNode* op) final { // partition const loop when sets partition_const_loop_ if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) { + // always treat var with hint to be partitioned const VarNode* var = op->loop_var.get(); + if (partition_hint_vars.count(var)) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { @@ -117,6 +123,12 @@ class CandidateSelector final : public StmtExprVisitor { Var var = iv->var; runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) { + // always treat var with hint to be partitioned + if (partition_hint_vars.count(var.get())) { + candidates.insert(GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + return; + } record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { @@ -125,6 +137,15 @@ class CandidateSelector final : public StmtExprVisitor { record_.erase(var.get()); return; } + } else if (op->attr_key == attr::pragma_loop_partition_hint) { + const VarNode* var = nullptr; + if (op->node->IsInstance()) { + var = op->node.as(); + } else if (op->node->IsInstance()) { + var = op->node.as()->var.get(); + } + ICHECK(var); + 
partition_hint_vars.insert(var); } StmtExprVisitor::VisitStmt_(op); } @@ -162,6 +183,7 @@ class CandidateSelector final : public StmtExprVisitor { } std::unordered_set candidates; + std::unordered_set partition_hint_vars; private: bool in_likely_{false}; @@ -170,15 +192,28 @@ class CandidateSelector final : public StmtExprVisitor { std::unordered_map record_; }; +// Finder try best to find partitions for hinted vars +#define DEFINE_PARTITION_FINDER_VISIT_CMP_OP(OpNodeT) \ + void VisitExpr_(const OpNodeT* op) final { \ + if (has_partition_hint_) { \ + DeduceCondition(GetRef(op)); \ + return; \ + } \ + StmtExprVisitor::VisitExpr_(op); \ + } + // Populate partitions data structure, i.e., for a specific variable, -// find an interval in which each condition -// (currently, "likely" conditions) has fixed true or false value +// find an interval in which each condition has fixed true or false value class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + const std::unordered_map& relax_map, + bool has_partition_hint) + : current_var_(current_var), + has_partition_hint_(has_partition_hint), + hint_map_(hint_map), + relax_map_(relax_map) { for (const auto& kv : hint_map) { out_vars_.insert(kv.first); } @@ -218,33 +253,43 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { - PrimExpr cond = op->args[0]; - if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { - // For cond, find out the interval, if exists, in which we can prove that cond is - // true. Also find the interval, if exists, in which we can prove that cond is - // false. 
- IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is true within interval - partitions[{cond, true}] = interval; - } - PrimExpr inverse_cond = InverseCond(cond); - if (inverse_cond.defined()) { - IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); - if (!interval.IsNothing()) { - // cond is false within interval - partitions[{cond, false}] = interval; - } - } - } + DeduceCondition(op->args[0]); } else { StmtExprVisitor::VisitExpr_(op); } } + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(GTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LENode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(LTNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(EQNode); + DEFINE_PARTITION_FINDER_VISIT_CMP_OP(NENode); + Partition partitions; private: + void DeduceCondition(const PrimExpr& cond) { + // For cond, find out the interval, if exists, in which we can prove that cond is + // true. Also find the interval, if exists, in which we can prove that cond is + // false. 
+ if (UsesVar(cond, [this](const VarNode* var) { return var == current_var_.get(); })) { + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is true within interval + partitions[{cond, true}] = interval; + } + PrimExpr inverse_cond = InverseCond(cond); + if (inverse_cond.defined()) { + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + if (!interval.IsNothing()) { + // cond is false within interval + partitions[{cond, false}] = interval; + } + } + } + } + PrimExpr InverseCond(const PrimExpr& cond) { PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { @@ -270,6 +315,7 @@ class PartitionFinder : public StmtExprVisitor { } Var current_var_; + bool has_partition_hint_; std::unordered_set out_vars_; std::unordered_map hint_map_; std::unordered_map relax_map_; @@ -472,7 +518,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim // include hint of var. hint_map_.insert({var.get(), IntSet::Interval(min, max)}); - PartitionFinder finder(var, hint_map_, relax_map_); + bool has_partition_hint_ = selector.partition_hint_vars.count(var.get()); + PartitionFinder finder(var, hint_map_, relax_map_, has_partition_hint_); finder(body); hint_map_.erase(var.get()); @@ -601,7 +648,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b } } -class RemoveLikelyTags : public StmtExprMutator { +class RemoveLikelyTagsAndHints : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { @@ -611,12 +658,19 @@ class RemoveLikelyTags : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::pragma_loop_partition_hint) { + return VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } }; Stmt LoopPartition(Stmt stmt, bool partition_const_loop, bool 
no_unroll_loop_with_extent_one) { stmt = LoopPartitioner(partition_const_loop, no_unroll_loop_with_extent_one) .VisitAndMutate(std::move(stmt)); - stmt = RemoveLikelyTags()(std::move(stmt)); + stmt = RemoveLikelyTagsAndHints()(std::move(stmt)); return stmt; } diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc index 2f8fbe0ea6e79..b48749c4c77c1 100644 --- a/src/tir/transforms/lower_match_buffer.cc +++ b/src/tir/transforms/lower_match_buffer.cc @@ -188,25 +188,25 @@ class MatchBufferLower : public StmtExprMutator { Load load = Downcast(source_buffer.vload(indices, source_buffer->dtype)); Bind(buffer->elem_offset, load->index, buffer->name + ".elem_offset"); CHECK(analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) - << "The source elem_offset " << buffer->elem_offset - << " does not satisfy the offset_factor " << buffer->offset_factor << "."; + << "The source elem_offset " << load->index << " does not satisfy the offset_factor " + << buffer->offset_factor << "."; } // Step 2.3. Check and update strides // Check if target buffer strides are defined + ICHECK(source->region.size() >= buffer->shape.size()); + int offset = source->region.size() - buffer->shape.size(); if (!buffer->strides.empty()) { ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); PrimExpr stride = make_const(DataType::Int(32), 1); for (size_t i = buffer->shape.size(); i > 0; --i) { - const PrimExpr& shape = source_buffer->shape[i - 1]; + const PrimExpr& shape = source_buffer->shape[i - 1 + offset]; Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1)); stride *= shape; } } // Step 2.4. 
Check and update shape - ICHECK(source->region.size() >= buffer->shape.size()); - size_t offset = source->region.size() - buffer->shape.size(); for (size_t i = 0; i < buffer->shape.size(); ++i) { const Range& range = source->region[i + offset]; Bind(buffer->shape[i], range->extent, buffer->name + ".shape_" + std::to_string(i)); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 99d71ebe15bda..062d67eef1656 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -113,9 +113,14 @@ class BuiltinLower : public StmtExprMutator { op = stmt.as(); // Get constant allocation bound. int64_t nbytes = GetVectorBytes(op->dtype); + // If the buffers are for CPU and have global scope, + // and less than runtime::kMaxStackAlloca heuristic + // they are not serviced with TVMBackendWorkspaceAlloc calls + // to be placed on stack. if (device_type_.defined()) { if (const auto* dev_type = device_type_.as()) { - if (dev_type->value == kDLCPU) { + auto storage_scope = Downcast(op->buffer_var->type_annotation)->storage_scope; + if (dev_type->value == kDLCPU && storage_scope == "global") { int32_t constant_size = op->constant_allocation_size(); if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { return stmt; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 592a6a33375ee..409b7c2629548 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -478,6 +478,11 @@ class StoragePlanRewriter : public StmtExprMutator { uint64_t bits_offset{0}; }; + // Checks whether the storage_scope is especially tagged for a specific memory. + bool IsSpecialTaggedMemory(const StorageScope& scope) { + return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace"; + } + // Alllocate entry of node. 
// Event entry in liveness analysis struct EventEntry { @@ -516,7 +521,7 @@ class StoragePlanRewriter : public StmtExprMutator { // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; - if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { + if (IsSpecialTaggedMemory(e->scope)) { ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { @@ -550,7 +555,7 @@ class StoragePlanRewriter : public StmtExprMutator { make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); - if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { + if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) @@ -591,7 +596,7 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = analyzer_.Simplify(combo_size); e->new_alloc = Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); - if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") { + if (IsSpecialTaggedMemory(e->scope)) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits) diff --git a/tests/cpp/name_transforms_test.cc b/tests/cpp/name_transforms_test.cc new file mode 100644 index 0000000000000..9fc52e09dea80 --- /dev/null +++ b/tests/cpp/name_transforms_test.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../src/relay/backend/name_transforms.h" + +#include +#include + +using namespace tvm::relay::backend; +using namespace tvm::runtime; + +TEST(NameTransforms, ToCFunctionStyle) { + ASSERT_EQ(ToCFunctionStyle("TVM_Woof"), "TVMWoof"); + ASSERT_EQ(ToCFunctionStyle("TVM_woof"), "TVMWoof"); + ASSERT_EQ(ToCFunctionStyle("TVM_woof_woof"), "TVMWoofWoof"); + ASSERT_EQ(ToCFunctionStyle("TVMGen_woof_woof"), "TVMGenWoofWoof"); + EXPECT_THROW(ToCVariableStyle("Cake_Bakery"), InternalError); // Incorrect prefix + EXPECT_THROW(ToCFunctionStyle(""), InternalError); +} + +TEST(NameTransforms, ToCVariableStyle) { + ASSERT_EQ(ToCVariableStyle("TVM_Woof"), "tvm_woof"); + ASSERT_EQ(ToCVariableStyle("TVM_woof"), "tvm_woof"); + ASSERT_EQ(ToCVariableStyle("TVM_woof_Woof"), "tvm_woof_woof"); + EXPECT_THROW(ToCVariableStyle("Cake_Bakery"), InternalError); // Incorrect prefix + EXPECT_THROW(ToCVariableStyle(""), InternalError); +} + +TEST(NameTransforms, PrefixName) { + ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof"); + ASSERT_EQ(PrefixName({"woof"}), "TVM_woof"); + ASSERT_EQ(PrefixName({"woof", "moo"}), "TVM_woof_moo"); + EXPECT_THROW(PrefixName({}), InternalError); + EXPECT_THROW(PrefixName({""}), InternalError); +} + +TEST(NameTransforms, PrefixGeneratedName) { + ASSERT_EQ(PrefixGeneratedName({"Woof"}), "TVMGen_Woof"); + ASSERT_EQ(PrefixGeneratedName({"woof"}), "TVMGen_woof"); + 
ASSERT_EQ(PrefixGeneratedName({"woof", "moo"}), "TVMGen_woof_moo"); + EXPECT_THROW(PrefixGeneratedName({}), InternalError); + EXPECT_THROW(PrefixGeneratedName({""}), InternalError); +} + +TEST(NameTransforms, CombineNames) { + ASSERT_EQ(CombineNames({"woof"}), "woof"); + ASSERT_EQ(CombineNames({"Woof", "woof"}), "Woof_woof"); + ASSERT_EQ(CombineNames({"Woof", "woof", "woof"}), "Woof_woof_woof"); + ASSERT_EQ(CombineNames({"Woof", "moo", "t"}), "Woof_moo_t"); + + EXPECT_THROW(CombineNames({}), InternalError); + EXPECT_THROW(CombineNames({""}), InternalError); + EXPECT_THROW(CombineNames({"Woof", ""}), InternalError); + EXPECT_THROW(CombineNames({"", "Woof"}), InternalError); +} + +TEST(NameTransforms, SanitizeName) { + ASSERT_EQ(SanitizeName("+_+ "), "_"); + ASSERT_EQ(SanitizeName("input+"), "input_"); + ASSERT_EQ(SanitizeName("input-"), "input_"); + ASSERT_EQ(SanitizeName("input++"), "input_"); + ASSERT_EQ(SanitizeName("woof:1"), "woof_1"); + EXPECT_THROW(SanitizeName(""), InternalError); +} + +TEST(NameTransforms, CombinedLogic) { + ASSERT_EQ(ToCFunctionStyle(PrefixName({"Device", "target", "Invoke"})), "TVMDeviceTargetInvoke"); + ASSERT_EQ(ToCFunctionStyle(PrefixGeneratedName({"model", "Run"})), "TVMGenModelRun"); + ASSERT_EQ(ToCVariableStyle(PrefixName({"Device", "target", "t"})), "tvm_device_target_t"); + ASSERT_EQ(ToCVariableStyle(PrefixGeneratedName({"model", "Devices"})), "tvmgen_model_devices"); +} diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc index c1e568e4cede8..e32fd32012a68 100644 --- a/tests/cpp/parallel_for_test.cc +++ b/tests/cpp/parallel_for_test.cc @@ -22,6 +22,7 @@ #include #include +#include #include TEST(ParallelFor, Basic) { @@ -90,7 +91,7 @@ TEST(ParallelFor, NestedWithNormalForLoop) { } } -TEST(Parallelfor, NestedWithParallelFor) { +TEST(ParallelFor, NestedWithParallelFor) { // Currently do not support using nested parallel_for using tvm::support::parallel_for; @@ -118,3 +119,42 @@ TEST(ParallelFor, Exception) { 
} ICHECK(exception); } + +TEST(ParallelForDynamic, Basic) { + using tvm::support::parallel_for_dynamic; + int a[1000]; + int num_threads = std::thread::hardware_concurrency(); + parallel_for_dynamic(0, 1000, num_threads, [&a](int thread_id, int i) { a[i] = i; }); + for (int i = 0; i < 1000; i++) { + ICHECK_EQ(a[i], i); + } +} + +TEST(ParallelForDynamic, ExceptionOnMain) { + using tvm::support::parallel_for_dynamic; + int num_threads = 1; + bool exception = false; + try { + parallel_for_dynamic(0, 10, num_threads, [](int thread_id, int task_id) { + if (thread_id == 0) { + LOG(FATAL) << "Error"; + } + }); + } catch (const std::exception& e) { + exception = true; + } + ICHECK(exception); +} + +TEST(ParallelForDynamic, ExceptionOnArbitrary) { + using tvm::support::parallel_for_dynamic; + int num_threads = 3; + bool exception = false; + try { + parallel_for_dynamic(0, 100, num_threads, + [](int thread_id, int task_id) { LOG(FATAL) << "Error"; }); + } catch (const std::exception& e) { + exception = true; + } + ICHECK(exception); +} diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/relay/transforms/device_domains_test.cc new file mode 100644 index 0000000000000..8f263c3b3273b --- /dev/null +++ b/tests/cpp/relay/relay/transforms/device_domains_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Just a smoke test for the device planner's unification domain, mostly to tease out how we'd + * like to organize our cpp unit tests for functionality that's not obviously a Pass or should + * be exposed via FFI. + */ + +// TODO(mbs): Revisit cpp unit test layout or setup include dir at root of src/ +#include "../../../src/relay/transforms/device_domains.h" + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { +namespace { + +IRModule TestModule() { + return InferType()(tvm::parser::ParseModule("test", R"( + #[version = "0.0.5"] + def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) { + add(%x, %y) + } + )")); +} + +TEST(DeviceDomains, SmokeTest) { + DeviceDomains domains; + IRModule mod = TestModule(); + Function f = Downcast(mod->Lookup("f")); + + DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast(f->body)); + DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]); + DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]); + DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type); + std::vector arg_and_results; + arg_and_results.push_back(x_domain); + arg_and_results.push_back(y_domain); + arg_and_results.push_back(result_domain); + DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results)); + domains.Unify(actual_add_domain, implied_add_domain); + domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA)); + + EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA); + 
EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA); +} + +} // namespace +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/runtime/logging_test.cc b/tests/cpp/runtime/logging_test.cc new file mode 100644 index 0000000000000..a4e6c01444e6f --- /dev/null +++ b/tests/cpp/runtime/logging_test.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include + +namespace tvm { +namespace runtime { +namespace detail { +namespace { + +TEST(TvmLogDebugSettings, Disabled) { + TvmLogDebugSettings settings = TvmLogDebugSettings::ParseSpec(nullptr); + EXPECT_FALSE(settings.dlog_enabled()); + + settings = TvmLogDebugSettings::ParseSpec(""); + EXPECT_FALSE(settings.dlog_enabled()); + + settings = TvmLogDebugSettings::ParseSpec("0"); + EXPECT_FALSE(settings.dlog_enabled()); +} + +TEST(TvmLogDebugSettings, DlogOnly) { + TvmLogDebugSettings settings = TvmLogDebugSettings::ParseSpec("1"); + EXPECT_TRUE(settings.dlog_enabled()); + EXPECT_FALSE(settings.VerboseEnabled("my/filesytem/src/foo/bar.cc", 0)); +} + +TEST(TvmLogDebugSettings, VLogEnabledDefault) { + TvmLogDebugSettings settings = TvmLogDebugSettings::ParseSpec("DEFAULT=3"); + EXPECT_TRUE(settings.dlog_enabled()); + EXPECT_TRUE(settings.VerboseEnabled("my/filesytem/src/foo/bar.cc", 3)); + EXPECT_FALSE(settings.VerboseEnabled("my/filesytem/src/foo/bar.cc", 4)); +} + +TEST(TvmLogDebugSettings, VLogEnabledComplex) { + TvmLogDebugSettings settings = + TvmLogDebugSettings::ParseSpec("foo/bar.cc=3;baz.cc=-1;DEFAULT=2;another/file.cc=4"); + EXPECT_TRUE(settings.dlog_enabled()); + EXPECT_TRUE(settings.VerboseEnabled("my/filesystem/src/foo/bar.cc", 3)); + EXPECT_FALSE(settings.VerboseEnabled("my/filesystem/src/foo/bar.cc", 4)); + EXPECT_TRUE(settings.VerboseEnabled("my/filesystem/src/foo/other.cc", 2)); + EXPECT_FALSE(settings.VerboseEnabled("my/filesystem/src/foo/other.cc", 3)); + EXPECT_FALSE(settings.VerboseEnabled("my/filesystem/src/baz.cc", 0)); +} + +TEST(TvmLogDebugSettings, IllFormed) { + EXPECT_THROW(TvmLogDebugSettings::ParseSpec("foo/bar.cc=bogus;"), InternalError); +} + +} // namespace +} // namespace detail +} // namespace runtime +} // namespace tvm diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index 0677292371aef..1b45ac783c29a 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -68,6 
+68,7 @@ "plist", "xcworkspacedata", "storyboard", + "xcscheme", # hw/chisel "sbt", "properties", diff --git a/tests/lint/flake8.sh b/tests/lint/flake8.sh index 43ade61c78933..b16d97fce3612 100755 --- a/tests/lint/flake8.sh +++ b/tests/lint/flake8.sh @@ -16,5 +16,5 @@ # specific language governing permissions and limitations # under the License. -# Disabled until docker images are rebuilt -# python3 -m flake8 . --count --select=E9,F63,F7 --show-source --statistics + +python3 -m flake8 . --count --select=E9,F63,F7 --show-source --statistics diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 7c19b62ac63dc..177ca8aa269e8 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -20,44 +20,16 @@ import pytest -from tvm.micro import project -import tvm.contrib.utils -import tvm.target.target +import test_utils -TEMPLATE_PROJECT_DIR = ( - pathlib.Path(__file__).parent - / ".." - / ".." - / ".." - / "apps" - / "microtvm" - / "zephyr" - / "template_project" -).resolve() - - -def zephyr_boards() -> dict: - """Returns a dict mapping board to target model""" - template = project.TemplateProject.from_directory(TEMPLATE_PROJECT_DIR) - project_options = template.info()["project_options"] - for option in project_options: - if option["name"] == "zephyr_board": - boards = option["choices"] - if option["name"] == "zephyr_model": - models = option["choices"] - - arduino_boards = {boards[i]: models[i] for i in range(len(boards))} - return arduino_boards - - -ZEPHYR_BOARDS = zephyr_boards() +from tvm.contrib.utils import tempdir def pytest_addoption(parser): parser.addoption( "--zephyr-board", required=True, - choices=ZEPHYR_BOARDS.keys(), + choices=test_utils.ZEPHYR_BOARDS.keys(), help=("Zephyr board for test."), ) parser.addoption( @@ -104,4 +76,4 @@ def temp_dir(board): if not os.path.exists(board_workspace.parent): os.makedirs(board_workspace.parent) - return tvm.contrib.utils.tempdir(board_workspace) + return 
tempdir(board_workspace) diff --git a/tests/micro/zephyr/test_utils.py b/tests/micro/zephyr/test_utils.py new file mode 100644 index 0000000000000..54c3de252f8a9 --- /dev/null +++ b/tests/micro/zephyr/test_utils.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import json +import pathlib + + +TEMPLATE_PROJECT_DIR = ( + pathlib.Path(__file__).parent + / ".." + / ".." + / ".." 
+ / "apps" + / "microtvm" + / "zephyr" + / "template_project" +).resolve() + +BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" + + +def zephyr_boards() -> dict: + """Returns a dict mapping board to target model""" + with open(BOARDS) as f: + board_properties = json.load(f) + + boards_model = {board: info["model"] for board, info in board_properties.items()} + return boards_model + + +ZEPHYR_BOARDS = zephyr_boards() + + +def qemu_boards(board: str): + """Returns True if board is QEMU.""" + with open(BOARDS) as f: + board_properties = json.load(f) + + qemu_boards = [name for name, board in board_properties.items() if board["is_qemu"]] + return board in qemu_boards + + +def has_fpu(board: str): + """Returns True if board has FPU.""" + with open(BOARDS) as f: + board_properties = json.load(f) + + fpu_boards = [name for name, board in board_properties.items() if board["fpu"]] + return board in fpu_boards diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index b6396ce533158..be1f231156adb 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -14,14 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- import logging import os import pathlib import subprocess import sys import logging -import json import pytest import numpy as np @@ -29,19 +27,12 @@ from PIL import Image import tvm -import tvm.rpc -import tvm.micro -import tvm.testing import tvm.relay as relay from tvm.relay.testing import byoc - from tvm.contrib import utils -from tvm.relay.expr_functor import ExprMutator -from tvm.relay.op.annotation import compiler_begin, compiler_end - from tvm.micro.testing import check_tune_log -import conftest +import test_utils _LOG = logging.getLogger(__name__) @@ -58,16 +49,24 @@ def _make_sess_from_op( def _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config): + config_main_stack_size = None + if test_utils.qemu_boards(zephyr_board): + config_main_stack_size = 1536 + + project_options = { + "project_type": "host_driven", + "west_cmd": west_cmd, + "verbose": bool(build_config.get("debug")), + "zephyr_board": zephyr_board, + } + if config_main_stack_size is not None: + project_options["config_main_stack_size"] = config_main_stack_size + project = tvm.micro.generate_project( - str(conftest.TEMPLATE_PROJECT_DIR), + str(test_utils.TEMPLATE_PROJECT_DIR), mod, temp_dir / "project", - { - "project_type": "host_driven", - "west_cmd": west_cmd, - "verbose": bool(build_config.get("debug")), - "zephyr_board": zephyr_board, - }, + project_options, ) project.build() project.flash() @@ -89,7 +88,7 @@ def _make_add_sess(temp_dir, model, zephyr_board, west_cmd, build_config, dtype= def test_add_uint(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. 
@@ -109,22 +108,12 @@ def test_basic_add(sess): test_basic_add(sess) -def has_fpu(zephyr_board): - sys.path.insert(0, str(conftest.TEMPLATE_PROJECT_DIR)) - try: - import microtvm_api_server - finally: - sys.path.pop(0) - - return microtvm_api_server.Handler._has_fpu(zephyr_board) - - # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro def test_add_float(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" - model = conftest.ZEPHYR_BOARDS[board] - if not has_fpu(board): + model = test_utils.ZEPHYR_BOARDS[board] + if not test_utils.has_fpu(board): pytest.skip(f"FPU not enabled for {board}") build_config = {"debug": tvm_debug} @@ -150,7 +139,7 @@ def test_basic_add(sess): def test_platform_timer(temp_dir, board, west_cmd, tvm_debug): """Test compiling the on-device runtime.""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. 
@@ -178,7 +167,7 @@ def test_basic_add(sess): @tvm.testing.requires_micro def test_relay(temp_dir, board, west_cmd, tvm_debug): """Testing a simple relay graph""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} shape = (10,) dtype = "int8" @@ -209,7 +198,7 @@ def test_relay(temp_dir, board, west_cmd, tvm_debug): @tvm.testing.requires_micro def test_onnx(temp_dir, board, west_cmd, tvm_debug): """Testing a simple ONNX model.""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} this_dir = pathlib.Path(os.path.dirname(__file__)) @@ -286,7 +275,7 @@ def check_result( @tvm.testing.requires_micro def test_byoc_microtvm(temp_dir, board, west_cmd, tvm_debug): """This is a simple test case to check BYOC capabilities of microTVM""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} x = relay.var("x", shape=(10, 10)) w0 = relay.var("w0", shape=(10, 10)) @@ -366,7 +355,7 @@ def _make_add_sess_with_shape(temp_dir, model, zephyr_board, west_cmd, shape, bu @tvm.testing.requires_micro def test_rpc_large_array(temp_dir, board, west_cmd, tvm_debug, shape): """Test large RPC array transfer.""" - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. 
@@ -385,9 +374,8 @@ def test_tensors(sess): @tvm.testing.requires_micro def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): """Test AutoTune for microTVM Zephyr""" - import tvm.relay as relay - - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] + build_config = {"debug": tvm_debug} # Create a Relay model data_shape = (1, 3, 16, 16) @@ -420,18 +408,22 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target) assert len(tasks) > 0 - repo_root = pathlib.Path( - subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() - ) - template_project_dir = repo_root / "apps" / "microtvm" / "zephyr" / "template_project" + config_main_stack_size = None + if test_utils.qemu_boards(board): + config_main_stack_size = 1536 + + project_options = { + "zephyr_board": board, + "west_cmd": west_cmd, + "verbose": 1, + "project_type": "host_driven", + } + if config_main_stack_size is not None: + project_options["config_main_stack_size"] = config_main_stack_size + module_loader = tvm.micro.AutoTvmModuleLoader( - template_project_dir=template_project_dir, - project_options={ - "zephyr_board": board, - "west_cmd": west_cmd, - "verbose": 1, - "project_type": "host_driven", - }, + template_project_dir=test_utils.TEMPLATE_PROJECT_DIR, + project_options=project_options, ) timeout = 200 @@ -473,21 +465,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): lowered = tvm.relay.build(mod, target=target, params=params) temp_dir = utils.tempdir() - project = tvm.micro.generate_project( - str(template_project_dir), - lowered, - temp_dir / "project", - { - "zephyr_board": board, - "west_cmd": west_cmd, - "verbose": 1, - "project_type": "host_driven", - }, - ) - project.build() - project.flash() - - with tvm.micro.Session(project.transport()) as session: + with _make_session(temp_dir, board, west_cmd, lowered, build_config) as session: 
graph_mod = tvm.micro.create_local_graph_executor( lowered.get_graph_json(), session.get_system_lib(), session.device ) @@ -502,21 +480,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): lowered_tuned = tvm.relay.build(mod, target=target, params=params) temp_dir = utils.tempdir() - project = tvm.micro.generate_project( - str(template_project_dir), - lowered_tuned, - temp_dir / "project", - { - "zephyr_board": board, - "west_cmd": west_cmd, - "verbose": 1, - "project_type": "host_driven", - }, - ) - project.build() - project.flash() - - with tvm.micro.Session(project.transport()) as session: + with _make_session(temp_dir, board, west_cmd, lowered_tuned, build_config) as session: graph_mod = tvm.micro.create_local_graph_executor( lowered_tuned.get_graph_json(), session.get_system_lib(), session.device ) diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index 6c72d3d7becf8..f03b8ecce6d04 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import io import logging import os @@ -28,35 +27,19 @@ import numpy as np import tvm -import tvm.rpc -import tvm.micro from tvm.micro.project_api import server -import tvm.testing import tvm.relay as relay -from tvm.contrib import utils from tvm.contrib.download import download_testdata from tvm.micro.interface_api import generate_c_interface_header -import conftest - -_LOG = logging.getLogger(__name__) +import test_utils def _build_project(temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None): - template_project_dir = ( - pathlib.Path(__file__).parent - / ".." - / ".." - / ".." 
- / "apps" - / "microtvm" - / "zephyr" - / "template_project" - ).resolve() project_dir = temp_dir / "project" project = tvm.micro.generate_project( - str(template_project_dir), + str(test_utils.TEMPLATE_PROJECT_DIR), mod, project_dir, { @@ -145,7 +128,7 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug): ]: pytest.skip(msg="Model does not fit.") - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] input_shape = (1, 32, 32, 3) output_shape = (1, 10) build_config = {"debug": tvm_debug} @@ -227,7 +210,7 @@ def test_qemu_make_fail(temp_dir, board, west_cmd, tvm_debug): if board not in ["qemu_x86", "mps2_an521"]: pytest.skip(msg="Only for QEMU targets.") - model = conftest.ZEPHYR_BOARDS[board] + model = test_utils.ZEPHYR_BOARDS[board] build_config = {"debug": tvm_debug} shape = (10,) dtype = "float32" diff --git a/tests/python/contrib/test_ethosu/__init__.py b/tests/python/contrib/test_ethosu/__init__.py new file mode 100644 index 0000000000000..e23e5fc926b24 --- /dev/null +++ b/tests/python/contrib/test_ethosu/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test infrastructure for Arm(R) Ethos(TM)-U NPU related tests""" diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index fc795c066cb6b..aeed64ad4aec3 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -24,15 +24,31 @@ the command stream and perform an equivalency check for single operator test cases. """ +from typing import List +import os +import struct import numpy from enum import IntEnum +from ethosu.vela.register_command_stream_generator import CmdMode +from ethosu.vela.register_command_stream_generator import cmd0 +from ethosu.vela.register_command_stream_generator import cmd1 + import tvm from tvm import relay import tvm.relay.backend.contrib.ethosu.op as ethosu_ops from tvm.topi.nn.utils import get_pad_tuple +from tests.python.relay.aot.aot_test_utils import ( + AOTCompiledTestModel, + AOTDataLinkage, + AOTTestModel, + AOTTestRunner, + compile_models, + run_and_check, +) + class AttachType(IntEnum): kGroupRoot = 1 @@ -42,6 +58,218 @@ class AttachType(IntEnum): kScanUpdate = 5 +class VelaArtifacts: + def __init__(self): + self.cs = dict() + self.flash = dict() + self.sram = dict() + self.npu_ops = set() + + +def parse_relay_tflite_model(tflite_model, input_tensor, input_shape, input_dtype): + mod_, params_ = relay.frontend.from_tflite( + tflite_model, + shape_dict={input_tensor: input_shape}, + dtype_dict={input_tensor: input_dtype}, + ) + return mod_, params_ + + +def parse_tflite_model(model_file): + try: + import tflite + + return tflite.Model.GetRootAsModel(model_file, 0) + except AttributeError: + import tflite.Model + + return tflite.Model.Model.GetRootAsModel(model_file, 0) + + +def print_payload(payload): + cmds = deserialize_command_stream(payload) + for cmd_val in cmds: + cmd, val = parse_cmd(cmd_val) + s = str(cmd) + s = s.ljust(40) + s += str(val) + print(s) + + +def parse_cmd(binary_cmd): + code = binary_cmd[0] & 0x0000FFFF # lower 16 bits 
+ param = binary_cmd[0] >> 16 # higher 16 bits + payload_mode = CmdMode(code & CmdMode.Mask) + if payload_mode == CmdMode.Payload32: + command = cmd1(code & CmdMode.CmdOpMask) + value = binary_cmd[1] + else: + command = cmd0(code & CmdMode.CmdOpMask) + value = param + return command, value + + +def check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_cmds=None): + if ignore_cmds is None: + ignore_cmds = [] + if vela_value != tvm_value and vela_cmd not in ignore_cmds: + raise RuntimeError( + "ValueMismatch :: vela={}, tvm={} for command:{}".format( + vela_value, tvm_value, vela_cmd + ) + ) + + +def verify_cmms(cmms_tvm_blob, cmms_vela_blob): + vela_cmm = deserialize_command_stream(cmms_vela_blob) + tvm_cmm = deserialize_command_stream(cmms_tvm_blob) + cmms_zip = zip(vela_cmm, tvm_cmm) + + first_ifm_found = False + last_ofm_found = False + + ignore_commands = ( + cmd1.NPU_SET_DMA0_SRC, + cmd1.NPU_SET_DMA0_DST, + cmd1.NPU_SET_WEIGHT_BASE, + cmd1.NPU_SET_OFM_BASE0, + cmd1.NPU_SET_IFM_BASE0, + cmd1.NPU_SET_SCALE_BASE, + ) + + ofm_region_params = [] + ofm_bases = [] + for vela_cmm, tvm_cmm in cmms_zip: + vela_cmd, vela_value = parse_cmd(vela_cmm) + tvm_cmd, tvm_value = parse_cmd(tvm_cmm) + + assert vela_cmd == tvm_cmd + + # The first IFM region could be different, but it needs to be 1 and 3. 
+ if vela_cmd == cmd0.NPU_SET_IFM_REGION and not first_ifm_found: + if vela_value == 1 and tvm_value == 3: + first_ifm_found = True + continue + + if vela_cmd == cmd1.NPU_SET_IFM_BASE0 and not first_ifm_found: + if tvm_value != 0: + raise RuntimeError("ValueError :: tvm primary ifm base should be zero") + continue + + # OFM regions should be cached to be checked later + if vela_cmd == cmd0.NPU_SET_OFM_REGION: + ofm_region_params.append((vela_value, tvm_value)) + continue + + # OFM bases should be cached to be checked later + if vela_cmd == cmd1.NPU_SET_OFM_BASE0: + ofm_bases.append((vela_value, tvm_value)) + continue + + check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_commands) + + # The last OFM region could be different but it should be 1 and 4. + last_vela_ofm_region, last_tvm_ofm_region = ofm_region_params.pop(-1) + if not (last_vela_ofm_region == 1 and last_tvm_ofm_region == 4): + raise RuntimeError( + "ValueMismatch :: vela={}, tvm={} for last ofm region it should be 1 and 4 respectively".format( + last_vela_ofm_region, last_tvm_ofm_region + ) + ) + + # The rest of the OFM regions should be the same. 
+ for vela_value, tvm_value in ofm_region_params: + check_cmms_equivalency(vela_cmd, vela_value, tvm_value, ignore_commands) + + # The last OFM base should be zero for tvm + _, last_tvm_ofm_base = ofm_bases.pop(-1) + if not last_tvm_ofm_base == 0: + raise RuntimeError("ValueError :: tvm primary ofm base should be zero") + + +def deserialize_command_stream(blob): + assert isinstance(blob, bytes) + payload_bytes = struct.unpack("<{0}I".format(len(blob) // 4), blob) + cmms = [] + # remove_header + payload_bytes = payload_bytes[8:] + idx = 0 + while idx < len(payload_bytes): + cmd = [] + code = payload_bytes[idx] + idx += 1 + cmd.append(code) + payload_mode = CmdMode(code & CmdMode.Mask) + if payload_mode == CmdMode.Payload32: + value = payload_bytes[idx] + idx += 1 + cmd.append(value) + cmms.append(cmd) + return cmms + + +def _create_test_runner(accel): + file_dir = os.path.dirname(os.path.abspath(__file__)) + test_root = os.path.join(file_dir, "reference_system") + ethosu_macs = accel[accel.rfind("-") + 1 :] + return AOTTestRunner( + makefile="corstone300", + prologue=""" + uart_init(); + EthosuInit(); + """, + includes=["uart.h", "ethosu_55.h", "ethosu_mod.h", "hard_fault.h"], + parameters={"ETHOSU_TEST_ROOT": test_root, "NPU_VARIANT": ethosu_macs}, + pass_config={ + "relay.ext.ethosu.options": { + "accelerator_config": accel, + } + }, + ) + + +def build_source(module, inputs, outputs, accel="ethos-u55-256"): + test_runner = _create_test_runner(accel) + return compile_models( + models=AOTTestModel( + module=module, + inputs=inputs, + outputs=outputs, + output_tolerance=10, + extra_memory_in_bytes=16 * 1024 * 1024, + ), + interface_api="c", + use_unpacked_api=True, + workspace_byte_alignment=16, + pass_config=test_runner.pass_config, + ) + + +def verify_source( + models: List[AOTCompiledTestModel], + accel="ethos-u55-256", +): + """ + This method verifies the generated source from an NPU module by building it and running on an FVP. 
+ """ + interface_api = "c" + test_runner = _create_test_runner(accel) + run_and_check( + models, + test_runner, + interface_api, + workspace_byte_alignment=16, + data_linkage=AOTDataLinkage(section="ethosu_scratch", alignment=16), + ) + + +def flatten_numpy_data(data): + """Flatten the numpy tensor to be single dimensional""" + total_elements = data.size + reshaped_data = numpy.reshape(data, [total_elements]) + return reshaped_data + + def generate_weights_data(shape, dtype): size = 1 for dim in shape: diff --git a/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake b/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake new file mode 100644 index 0000000000000..6aeb0b7cc7c15 --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/arm-none-eabi-gcc.cmake @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if (__TOOLCHAIN_LOADED) + return() +endif() +set(__TOOLCHAIN_LOADED TRUE) + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_C_COMPILER "arm-none-eabi-gcc") +set(CMAKE_CXX_COMPILER "arm-none-eabi-g++") +set(CMAKE_SYSTEM_PROCESSOR "cortex-m55" CACHE STRING "Select Cortex-M architecture. 
(cortex-m0, cortex-m3, cortex-m33, cortex-m4, cortex-m55, cortex-m7, etc)") + +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) + +SET(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +SET(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_STANDARD 99) +set(CMAKE_CXX_STANDARD 14) + +# The system processor could for example be set to cortex-m33+nodsp+nofp. +set(__CPU_COMPILE_TARGET ${CMAKE_SYSTEM_PROCESSOR}) +string(REPLACE "+" ";" __CPU_FEATURES ${__CPU_COMPILE_TARGET}) +list(POP_FRONT __CPU_FEATURES CMAKE_SYSTEM_PROCESSOR) + +string(FIND ${__CPU_COMPILE_TARGET} "+" __OFFSET) +if(__OFFSET GREATER_EQUAL 0) + string(SUBSTRING ${__CPU_COMPILE_TARGET} ${__OFFSET} -1 CPU_FEATURES) +endif() + +# Add -mcpu to the compile options to override the -mcpu the CMake toolchain adds +add_compile_options(-mcpu=${__CPU_COMPILE_TARGET}) + +# Set floating point unit +if("${__CPU_COMPILE_TARGET}" MATCHES "\\+fp") + set(FLOAT hard) +elseif("${__CPU_COMPILE_TARGET}" MATCHES "\\+nofp") + set(FLOAT soft) +elseif("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m33" OR + "${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "cortex-m55") + set(FLOAT hard) +else() + set(FLOAT soft) +endif() + +add_compile_options(-mfloat-abi=${FLOAT}) +add_link_options(-mfloat-abi=${FLOAT}) + +# Link target +add_link_options(-mcpu=${__CPU_COMPILE_TARGET}) +add_link_options(-Xlinker -Map=output.map) + +# +# Compile options +# +set(cxx_flags "-fno-unwind-tables;-fno-rtti;-fno-exceptions") + +add_compile_options("-Wall;-Wextra;-Wsign-compare;-Wunused;-Wswitch-default;\ +-Wdouble-promotion;-Wredundant-decls;-Wshadow;-Wnull-dereference;\ +-Wno-format-extra-args;-Wno-unused-function;-Wno-unused-label;\ +-Wno-missing-field-initializers;-Wno-return-type;-Wno-format;-Wno-int-conversion" + "$<$:${cxx_flags}>" + ) diff --git a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm b/tests/python/contrib/test_ethosu/reference_system/ethosu_55.h similarity index 69% rename from 
apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm rename to tests/python/contrib/test_ethosu/reference_system/ethosu_55.h index eb538f07bf492..41ce284956e2e 100644 --- a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm +++ b/tests/python/contrib/test_ethosu/reference_system/ethosu_55.h @@ -16,31 +16,12 @@ * specific language governing permissions and limitations * under the License. */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ -/*! - * \brief A hook to launch RPC server via xcodebuild test - * \file tvmrpcLauncher.mm - */ - -#import -#import "TVMRuntime.h" - -@interface tvmrpcLauncher : XCTestCase - -@end - -@implementation tvmrpcLauncher - -- (void)setUp { - [super setUp]; -} - -- (void)tearDown { - [super tearDown]; -} - -- (void)testRPC { - [TVMRuntime launchSyncServer]; -} +/* Define Arm(R) Ethos(TM)-U55 specific IRQs & base address */ +#define ETHOSU_NPU_FAIL (1 << 4) +#define ETHOSU_IRQ ((IRQn_Type)56) +#define ETHOSU_BASE_ADDRESS ((void*)0x48102000) -@end +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_55_H_ diff --git a/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h b/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h new file mode 100644 index 0000000000000..aa5c1026bd6d8 --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/ethosu_mod.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ + +#include +// TODO(@grant-arm): Remove device specific information once RTOS support is available +#include +#include + +#include "ethosu_55.h" + +struct ethosu_driver* ethosu0_driver = ðosu_drv; + +void ethosuIrqHandler0() { ethosu_irq_handler(ethosu0_driver); } + +// Initialize Arm(R) Ethos(TM)-U NPU driver +int EthosuInit() { + if (ethosu_init(ethosu0_driver, (void*)ETHOSU_BASE_ADDRESS, NULL, 0, 1, 1)) { + printf("Failed to initialize NPU.\n"); + return -1; + } + + // Display Arm(R) Ethos(TM)-U version information useful for debugging issues + struct ethosu_version version; + ethosu_get_version(ethosu0_driver, &version); + printf( + "version={major=%u, minor=%u, status=%u}, product={major=%u}, arch={major=%u, minor=%u, " + "patch=%u}\n", + version.id.version_major, version.id.version_minor, version.id.version_status, + version.id.product_major, version.id.arch_major_rev, version.id.arch_minor_rev, + version.id.arch_patch_rev); + printf("macs_per_cc=%u, cmd_stream_version=%u, shram_size=%u\n", version.cfg.macs_per_cc, + version.cfg.cmd_stream_version, version.cfg.shram_size); + + // Assumes SCB->VTOR points to RW memory + NVIC_SetVector(ETHOSU_IRQ, (uint32_t)ðosuIrqHandler0); + NVIC_EnableIRQ(ETHOSU_IRQ); + + return 0; +} + +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_ETHOSU_MOD_H_ diff --git a/tests/python/contrib/test_ethosu/reference_system/hard_fault.h b/tests/python/contrib/test_ethosu/reference_system/hard_fault.h new file mode 100644 
index 0000000000000..9d349004848be --- /dev/null +++ b/tests/python/contrib/test_ethosu/reference_system/hard_fault.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ +#define TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ + +struct ExcContext { + uint32_t r0; + uint32_t r1; + uint32_t r2; + uint32_t r3; + uint32_t r12; + uint32_t lr; + uint32_t pc; + uint32_t xPsr; +}; +void HardFault_Handler() { + int irq; + struct ExcContext* e; + uint32_t sp; + asm volatile( + "mrs %0, ipsr \n" // Read IPSR (Exception number) + "sub %0, #16 \n" // Get it into IRQn_Type range + "tst lr, #4 \n" // Select the stack which was in use + "ite eq \n" + "mrseq %1, msp \n" + "mrsne %1, psp \n" + "mov %2, sp \n" + : "=r"(irq), "=r"(e), "=r"(sp)); + printf("Hard fault. 
irq=%d, pc=0x%08lu, lr=0x%08lu, xpsr=0x%08lu, sp=0x%08lu\n", irq, e->pc, + e->lr, e->xPsr, sp); + printf("%11s cfsr=0x%08lu bfar=0x%08lu\n", "", SCB->CFSR, SCB->BFAR); + printf("EXITTHESIM\n"); + while (1 == 1) + ; +} + +#endif // TVM_RUNTIME_CONTRIB_ETHOS_U_HARD_FAULT_H_ diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py new file mode 100644 index 0000000000000..2e0fb4b78cf1a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +import os +import numpy as np +import pathlib + +import tvm +import tvm.micro as micro +from tvm import relay +from tvm.relay.backend.contrib import ethosu +from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.op.contrib.ethosu import partition_for_ethosu +from tests.python.relay.aot.aot_test_utils import generate_ref_data + +from . import relay_ir_builder +from . 
import infra + +ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + +def infer_type_function_pass(func): + mod = tvm.IRModule() + mod["test"] = func + mod = relay.transform.InferType()(mod) + return mod["test"] + + +def get_shape_expr(in_expr, out_expr): + main_f = relay.Function([in_expr], out_expr) + main_f = infer_type_function_pass(main_f) + shape = [int(i) for i in main_f.body.checked_type.shape] + return shape + + +@pytest.mark.parametrize( + "accel_type", + ACCEL_TYPES, +) +def test_ethosu_conv2d(accel_type): + def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32) + c1_params.kernel.sc = relay.const(np.random.rand(32) * 2, "float32") + c1_params.strides = (1, 1) + c1_params.pad = "VALID" + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + f = relay.Function([input0], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c1_params] + + def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) + c1_params.strides = (2, 2) + c1_params.pad = "VALID" + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) 
+ + c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c2_params.ifm.shape = c1_params.ofm.shape + c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) + c2_params.strides = (1, 1) + c2_params.pad = "SAME" + c2_params.update_output_qnn_params() + c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) + c2_params.ofm.shape = get_shape_expr(input0, c2) + + f = relay.Function([input0], c2) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c2_params, c1_params] + + def create_graph_activation(input_tensor_name, input_tensor_shape, input_tensor_dtype): + c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c1_params.ifm.shape = input_tensor_shape + c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8) + c1_params.strides = (2, 2) + c1_params.pad = "VALID" + c1_params.activation = "CLIP" + c1_params.clip_min = 90 + c1_params.clip_max = 110 + c1_params.update_output_qnn_params( + input_tensor_dtype, input_tensor_dtype, input_tensor_dtype + ) + input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype) + c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0) + c1_params.ofm.shape = get_shape_expr(input0, c1) + + c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype) + c2_params.ifm.shape = c1_params.ofm.shape + c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16) + c2_params.strides = (1, 1) + c2_params.pad = "SAME" + c2_params.update_output_qnn_params() + c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1) + c2_params.ofm.shape = get_shape_expr(input0, c2) + + f = relay.Function([input0], c2) + mod = tvm.IRModule() + mod["main"] = f + return mod, [c2_params, c1_params] + + test_cases = [ + (create_graph_single, ["input", (1, 300, 300, 3), "int8"]), + (create_graph_double, ["input", (1, 128, 256, 4), "int8"]), + (create_graph_activation, ["input", (1, 64, 100, 4), "int8"]), + ] + np.random.seed(42) + for test_case in test_cases: + 
relay_module, conv_params = test_case[0](*test_case[1]) + input_tensor, input_shape, input_dtype = test_case[1] + mod = partition_for_ethosu(relay_module) + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(input_dtype) + input_data = { + input_tensor: np.random.randint( + in_min, high=in_max, size=input_shape, dtype=input_dtype + ) + } + output_data = generate_ref_data(relay_module, input_data) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 0e546ae2fd24f..eb3a4d8cb4dad 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -26,7 +26,7 @@ from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute -from infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d # fmt: off diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 52f6995c3aaa7..911a0e6eefc63 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -27,7 +27,8 @@ from tvm.relay.backend.contrib.ethosu import legalize, preprocess from tvm.relay.dataflow_pattern import 
* from tvm.relay.op.contrib.ethosu import * -import relay_ir_builder + +from . import relay_ir_builder def test_split_indices_legalize(): @@ -293,7 +294,7 @@ def verify_linear(ext_func, conv2d_params): ] for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = ethosu.partition_for_ethosu(mod) + mod = partition_for_ethosu(mod) mod = legalize.LegalizeEthosUConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -326,7 +327,7 @@ def create_graph_single_unsupported_ifm_layout( for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) - mod = ethosu.partition_for_ethosu(mod) + mod = partition_for_ethosu(mod) with pytest.raises( tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" ): diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py new file mode 100644 index 0000000000000..a4be923e09dbe --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +from tests.python.relay.aot.aot_test_utils import ( + convert_to_relay, + generate_ref_data, +) +import numpy as np + +import tvm +import tvm.micro as micro +from tvm import relay +from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.op.contrib.ethosu import partition_for_ethosu + +import tvm.relay.testing.tf as tf_testing + +from . import infra + +ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + +def test_forward_mobilenet_v1(accel_type="ethos-u55-256"): + """Test the Mobilenet V1 TF Lite model.""" + np.random.seed(23) + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/" + "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", + "mobilenet_v1_1.0_224_quant.tflite", + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + input_tensor = "input" + input_dtype = "uint8" + input_shape = (1, 224, 224, 3) + in_min, in_max = util.get_range_for_dtype_str(input_dtype) + input_data = np.random.randint(in_min, high=in_max, size=input_shape, dtype=input_dtype) + + relay_mod, params = convert_to_relay(tflite_model_buf, input_data, "input") + input_data = {input_tensor: input_data} + output_data = generate_ref_data(relay_mod, input_data) + + mod = partition_for_ethosu(relay_mod, params) + compiled_models = infra.build_source(mod, input_data, output_data, accel_type) + infra.verify_source(compiled_models, accel_type) + + +if __name__ == "__main__": + test_forward_mobilenet_v1() diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 96fe56d1778ef..f66b21b92a03a 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -24,7 +24,7 @@ from tvm.relay.testing import run_opt_pass from 
tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader -from infra import make_ethosu_conv2d, get_convolutional_args +from .infra import make_ethosu_conv2d, get_convolutional_args @pytest.mark.parametrize( diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 222dccacc9062..2d76cd654690d 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -25,7 +25,7 @@ from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants -from infra import make_ethosu_conv2d +from .infra import make_ethosu_conv2d # fmt: off diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index b07f8ea7f48b2..8077271ed4964 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -29,7 +29,7 @@ schedule_cache_reads, ) from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants -from infra import AttachType, make_ethosu_conv2d +from .infra import AttachType, make_ethosu_conv2d class TestTEGraph: diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py new file mode 100644 index 0000000000000..479a1032453a7 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -0,0 +1,770 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +import numpy as np + +import tvm +from tvm import tir +from tvm.tir import stmt_functor +from tvm.script import ty +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator +from tvm.relay.backend.contrib.ethosu import util +import ethosu.vela.api as vapi + + +# fmt: off +"""A sample tir test case for translator""" +@tvm.script.tir +class SingleEthosUConv2D: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_4 = tir.match_buffer(placeholder_1, [1, 1, 3, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, tir.load("uint8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, 
tir.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_4.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) +# fmt: on + + +# fmt: off +"""A sample tir test case with multiple convolutions for translator""" +@tvm.script.tir +class MultiEthosUConv2D: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_9 = tir.match_buffer(placeholder_3, [1, 1, 32, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_7 = tir.match_buffer(placeholder_1, [1, 1, 3, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_6 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_8 = tir.match_buffer(placeholder_2, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_4, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_conv2d_2 = tir.allocate([1024], "uint8", "global") + ethosu_conv2d_3 = tir.allocate([2048], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, tir.load("uint8", placeholder_6.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_7.data, 0), 0, 12, tir.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", 
ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_9.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, tir.load("uint8", placeholder_6.data, 96), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_7.data, 0), 0, 12, tir.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_9.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) +# fmt: on + + +# fmt: off +"""A sample tir test case with copy operations for translator""" +@tvm.script.tir +class MultiEthosUCopy: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_2, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = 
tir.match_buffer(placeholder_1, [8, 1, 1, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([256], "uint8", "global") + placeholder_d_global = tir.allocate([8], "int32", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", placeholder_4.data, 0), 256, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("int32", placeholder_5.data, 0), 8, tir.load("int32", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 0, 12, tir.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) +# fmt: on + + +# fmt: off +"""A TIR test module of weight streaming""" +@tvm.script.tir +class WeightStreamOnly: + def main(placeholder: ty.handle, ethosu_conv2d: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer_4 = tir.match_buffer(placeholder_5, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_7, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_1, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_6, [20], dtype="uint8", 
elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_3, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([144], "uint8", "global") + placeholder_d_global = tir.allocate([20], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 
16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_7.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A TIR test module of weight streaming and direct reading""" +@tvm.script.tir +class MixedRead: + def main(placeholder: ty.handle, 
placeholder_1: ty.handle, ethosu_conv2d: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer_5 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_9 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_10, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_6, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_8 = tir.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_conv2d_2 = tir.allocate([4096], "uint8", "global") + placeholder_global = tir.allocate([80], "uint8", "global") + placeholder_d_global = tir.allocate([20], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, 
tir.load("uint8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_5.data, 0), 592, 12, tir.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_9.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 80, tir.load("uint8", 
placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_8.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_buffer_info_extraction(): + test_cases = [ + { + # Stimulus + "tir_module": SingleEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 16], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + # Reference Outputs + "constants": { + "placeholder_4": 1, + "placeholder_5": 2, + }, + "data_buffers": { + "placeholder_3": ( + [1, 8, 8, 3], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_1": ( 
+ [1, 8, 8, 16], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + }, + }, + { + "tir_module": MultiEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32"), + 3: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 32, 8], "uint8" + ), + 4: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference Outputs + "constants": { + "placeholder_5": 4, + "placeholder_7": 1, + "placeholder_8": 2, + "placeholder_9": 3, + }, + "data_buffers": { + "placeholder_6": ( + [1, 8, 8, 3], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_1": ( + [1, 8, 8, 8], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_2": ([1024], "uint8", tir_to_cs_translator.BufferType.scratch), + "ethosu_conv2d_3": ([2048], "uint8", tir_to_cs_translator.BufferType.scratch), + }, + }, + ] + for test_case in test_cases: + buffer_info = tir_to_cs_translator.extract_buffer_info( + test_case["tir_module"], test_case["param_dict"] + ) + for buffer_var, info in buffer_info.items(): + buffer_name = buffer_var.name + if buffer_name in test_case["constants"].keys(): + assert ( + info.values == test_case["param_dict"][test_case["constants"][buffer_name]] + ).all() + assert ( + info.dtype == test_case["param_dict"][test_case["constants"][buffer_name]].dtype + ) + info.btype == tir_to_cs_translator.BufferType.constant + else: + assert list(info.shape) == test_case["data_buffers"][buffer_name][0] + assert info.dtype == test_case["data_buffers"][buffer_name][1] + assert info.btype == test_case["data_buffers"][buffer_name][2] + + +def test_translate_ethosu_conv2d(): + test_cases = [ + { + # Stimulus + "tir_module": SingleEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, 
np.iinfo("uint8").max, [1, 1, 3, 16], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + # Reference outputs + "ref": [ + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(8, 8, 3), + "tiles": vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(8, 8, 16), + "tiles": vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(128, 16, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + } + ], + }, + { + "tir_module": MultiEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32"), + 3: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 32, 8], "uint8" + ), + 4: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference Outputs + "ref": [ + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 3), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": 
vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": {"op": None}, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 8), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(64, 8, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 3), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + 
"ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 8), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(64, 8, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + ], + }, + ] + + def extract_ethosu_conv2d_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_conv2d_calls = list() + + def populate_ethosu_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_conv2d" + ): + ethosu_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls) + return ethosu_conv2d_calls + + for test_case in test_cases: + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_case["tir_module"]) + for idx, ethosu_conv2d_call in enumerate(ethosu_conv2d_calls): + ref = test_case["ref"][idx] + npu_op, w_zero_point = 
tir_to_cs_translator.translate_ethosu_conv2d(ethosu_conv2d_call) + # Compare IFM + assert npu_op.ifm.data_type == ref["ifm"]["data_type"] + assert npu_op.ifm.shape == ref["ifm"]["shape"] + assert npu_op.ifm.tiles.height_0 == ref["ifm"]["tiles"].height_0 + assert npu_op.ifm.tiles.height_1 == ref["ifm"]["tiles"].height_1 + assert npu_op.ifm.tiles.width_0 == ref["ifm"]["tiles"].width_0 + assert npu_op.ifm.quantization == ref["ifm"]["quantization"] + assert npu_op.ifm.layout == ref["ifm"]["layout"] + assert npu_op.ifm.strides == ref["ifm"]["strides"] + # Compare OFM + assert npu_op.ofm.data_type == ref["ofm"]["data_type"] + assert npu_op.ofm.shape == ref["ofm"]["shape"] + assert npu_op.ofm.tiles.height_0 == ref["ofm"]["tiles"].height_0 + assert npu_op.ofm.tiles.height_1 == ref["ofm"]["tiles"].height_1 + assert npu_op.ofm.tiles.width_0 == ref["ofm"]["tiles"].width_0 + assert npu_op.ofm.quantization == ref["ofm"]["quantization"] + assert npu_op.ofm.layout == ref["ofm"]["layout"] + assert npu_op.ofm.strides == ref["ofm"]["strides"] + # Compare kernel and padding + assert npu_op.kernel.__dict__ == ref["kernel"].__dict__ + assert npu_op.padding == ref["padding"] + # Compare activation + if ref["activation"]["op"] is None: + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == ref["activation"]["op"] + assert npu_op.activation.min == ref["activation"]["min"] + assert npu_op.activation.max == ref["activation"]["max"] + # Compare ifm upscaling + assert npu_op.ifm_upscale == ref["ifm_upscale"] + # Compare weight quantization parameters + assert w_zero_point == ref["w_zero_point"] + + +def test_translate_ethosu_copy(): + def extract_ethosu_copy_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single 
function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_copy_calls = list() + + def populate_ethosu_copy_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_copy" + ): + ethosu_copy_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_copy_calls) + return ethosu_copy_calls + + test_cases = [ + { + # Stimulus + "tir_module": MultiEthosUCopy(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [8, 1, 1, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference outputs + "ref": [ + { + "src": "placeholder_4", + "dest": "placeholder_global", + "length": 256, + }, + { + "src": "placeholder_5", + "dest": "placeholder_d_global", + "length": 8, + }, + ], + }, + ] + + for test_case in test_cases: + ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"]) + for idx, ethosu_copy_call in enumerate(ethosu_copy_calls): + npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_extern_call(ethosu_copy_call) + assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"] + assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"] + assert npu_dma_op.src.length == test_case["ref"][idx]["length"] + assert npu_dma_op.dest.length == test_case["ref"][idx]["length"] + + +def test_assign_addresses(): + test_cases = [ + { + # Stimulus + "tir_module": WeightStreamOnly(), + "param_dict": { + 2: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 6: np.random.randint(np.iinfo("uint8").min, 
np.iinfo("uint8").max, [144], "uint8"), + 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + }, + }, + { + # Stimulus + "tir_module": MixedRead(), + "param_dict": { + 1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [592], "uint8"), + 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"), + 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 6: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 10: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 11: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + }, + }, + ] + + def extract_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + extern_calls = list() + + def populate_extern_calls(stmt): + if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": + extern_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) + return extern_calls + + def collect_tir_buffer_info(npu_ops): + """This is run prior to address 
assigning to collect tir buffer information + for verification later on""" + _npu_op_tir_buffers = dict() + for npu_op in npu_ops: + if isinstance(npu_op, vapi.NpuDmaOperation): + _npu_op_tir_buffers[npu_op] = (npu_op.src.address, npu_op.dest.address) + elif issubclass(type(npu_op), vapi.NpuBlockOperation): + _npu_op_tir_buffers[npu_op] = ( + npu_op.ifm.tiles.addresses[0], + npu_op.ofm.tiles.addresses[0], + npu_op.weights, + npu_op.biases, + ) + return _npu_op_tir_buffers + + def _check_buffer(address, region, length, buffer_var): + """Checks whether the buffer information is valid with + original tir buffers. + - If its constant, this will check + the slice in the constant tensor has the values. + - If its scratch, this will check + the slice is within scratch and does not have conflicts + with other scratch tensors. + - If its input/output, this will check the + address is zero + """ + inverse_region_map = { + 0: tir_to_cs_translator.BufferType.constant, + 1: tir_to_cs_translator.BufferType.scratch, + 3: tir_to_cs_translator.BufferType.input, + 4: tir_to_cs_translator.BufferType.output, + } + buffer_type = inverse_region_map[region] + if buffer_type == tir_to_cs_translator.BufferType.constant: + ref = buffer_info[buffer_var].values + assert (constant_tensor[address : address + length] == ref).all() + # Every buffer is adjusted to align to 16 bytes + length = util.round_up(length, 16) + # Mark these constants are read at least once + constant_tensor_read_mask[address : address + length] = np.ones(length, dtype="uint8") + elif buffer_type == tir_to_cs_translator.BufferType.scratch: + shape = list(buffer_info[buffer_var].shape) + assert length == np.prod(shape) + assert address < scratch_size + # Every buffer is adjusted to align to 16 bytes + length = util.round_up(length, 16) + assert address + length <= scratch_size + # The scratch area should not be used by anyother buffer + assert not scratch_allocation_mask[address : address + length].any() + # The scratch 
area is marked as used + scratch_allocation_mask[address : address + length] = np.ones(length, dtype="uint8") + elif buffer_type == tir_to_cs_translator.BufferType.input: + assert address == 0 + else: + assert buffer_type == tir_to_cs_translator.BufferType.output + assert address == 0 + + def verify(npu_ops): + """This wrapper verifies the allocated addresses matches with original tir buffers""" + checked_buffers = set() + + def check_buffer(address, region, length, buffer_var): + if buffer_var not in checked_buffers: + _check_buffer(address, region, length, buffer_var) + checked_buffers.add(buffer_var) + + for npu_op in npu_ops: + if isinstance(npu_op, vapi.NpuDmaOperation): + src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + check_buffer( + npu_op.src.address, npu_op.src.region, npu_op.src.length, src_tir_buffer_var + ) + dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer_var + check_buffer( + npu_op.dest.address, + npu_op.dest.region, + npu_op.dest.length, + dest_tir_load, + ) + elif issubclass(type(npu_op), vapi.NpuBlockOperation): + ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + ifm_length = ( + npu_op.ifm.shape.height * npu_op.ifm.shape.width * npu_op.ifm.shape.depth + ) + check_buffer( + npu_op.ifm.tiles.addresses[0], + npu_op.ifm.region, + ifm_length, + ifm_tir_buffer_var, + ) + ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer_var + ofm_length = ( + npu_op.ofm.shape.height * npu_op.ofm.shape.width * npu_op.ofm.shape.depth + ) + check_buffer( + npu_op.ofm.tiles.addresses[0], + npu_op.ofm.region, + ofm_length, + ofm_tir_buffer_var, + ) + for idx, weight in enumerate(npu_op_tir_buffers[npu_op][2]): + assert isinstance(weight, vapi.NpuAddressRange) + check_buffer( + npu_op.weights[idx].address, + npu_op.weights[idx].region, + npu_op.weights[idx].length, + weight.address.buffer_var, + ) + for idx, bias in enumerate(npu_op_tir_buffers[npu_op][3]): + assert isinstance(bias, vapi.NpuAddressRange) + check_buffer( + 
npu_op.biases[idx].address, + npu_op.biases[idx].region, + npu_op.biases[idx].length, + bias.address.buffer_var, + ) + + for test_case in test_cases: + buffer_info = tir_to_cs_translator.extract_buffer_info( + test_case["tir_module"], test_case["param_dict"] + ) + extern_calls = extract_extern_calls(test_case["tir_module"]) + _npu_ops = list() + for extern_call in extern_calls: + _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_extern_call(extern_call)) + npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) + _npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses( + buffer_info, _npu_ops + ) + scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8") + constant_tensor_read_mask = np.zeros(constant_tensor.size, dtype="uint8") + verify(_npu_ops) + # This will be only 1 if all allocated scratch is used. + assert np.prod(scratch_allocation_mask) == 1 + # This will be only 1 if all constant tensors is read at least once. + assert np.prod(constant_tensor_read_mask) == 1 + + +if __name__ == "__main__": + test_buffer_info_extraction() + test_translate_ethosu_conv2d() + test_translate_ethosu_copy() + test_assign_addresses() diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index a86dd919d5caf..1f9c72e1c9cce 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -347,9 +347,7 @@ def verify(test_vec, mock_obj): assert mock_obj.call_args[1]["block_traversal"] == test_vec["block_traversal"] def create_mock(test_vec): - with patch( - "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights" - ) as mock_npu_encode_weights: + with patch("ethosu.vela.api.npu_encode_weights") as mock_npu_encode_weights: ifm_bitdepth = np.iinfo(test_vec["ifm_dtype"]).bits ifm_dtype = test_vec["ifm_dtype"] max = np.iinfo(ifm_dtype).max @@ -427,9 +425,7 @@ def verify(test_vec, mock_obj, packed_biases): assert 
test_vec["hw_shifts"][idx] == mock_obj.call_args_list[idx][0][2] def create_mock(test_vec): - with patch( - "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_bias" - ) as mock_npu_encode_bias: + with patch("ethosu.vela.api.npu_encode_bias") as mock_npu_encode_bias: mock_npu_encode_bias.return_value = bytearray(10) ifm_dtype = test_vec["ifm_dtype"] max = np.iinfo(ifm_dtype).max @@ -507,12 +503,8 @@ def test_encode_weights(accel): ] def create_mock(test_vec): - with patch( - "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights" - ) as mock_enc_w: - with patch( - "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_find_block_configs" - ) as mock_blk_cfg: + with patch("ethosu.vela.api.npu_encode_weights") as mock_enc_w: + with patch("ethosu.vela.api.npu_find_block_configs") as mock_blk_cfg: mock_blk_cfg.return_value = [vapi.NpuShape3D(8, 8, 8)] ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"]) buffer_info = tirtocs.extract_buffer_info( diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index defd628c60c92..3993fdae86f4c 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -361,6 +361,41 @@ def test_compile_tflite_module_with_external_codegen_vitis_ai(tflite_mobilenet_v assert os.path.exists(dumps_path) +def test_compile_tflite_module_with_external_codegen_ethosu( + tmpdir_factory, tflite_mobilenet_v1_1_quant +): + pytest.importorskip("tflite") + pytest.importorskip("ethosu.vela") + ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32"] + + output_dir = tmpdir_factory.mktemp("mlf") + + tvmc_model = tvmc.load(tflite_mobilenet_v1_1_quant) + + for accel_type in ACCEL_TYPES: + output_file_name = f"{output_dir}/file_{accel_type}.tar" + + tvmc_package = tvmc.compiler.compile_model( + tvmc_model, + target=f"ethos-u -accelerator_config={accel_type}, c -runtime=c --system-lib --link-params 
-mcpu=cortex-m55 --executor=aot", + output_format="mlf", + package_path=output_file_name, + pass_context_configs=["tir.disable_vectorize=true"], + ) + + # check whether an MLF package was created + assert os.path.exists(output_file_name) + + # check whether the expected number of C sources are in the tarfile + with tarfile.open(output_file_name) as mlf_package: + c_source_files = [ + name + for name in mlf_package.getnames() + if re.match(r"\./codegen/host/src/\D+\d+\.c", name) + ] + assert len(c_source_files) == 17 + + @mock.patch("tvm.relay.build") @mock.patch("tvm.driver.tvmc.composite_target.get_codegen_by_target") @mock.patch("tvm.driver.tvmc.load") diff --git a/tests/python/driver/tvmc/test_tvmc_common.py b/tests/python/driver/tvmc/test_tvmc_common.py index 31fa688ad7178..779611a7a3457 100644 --- a/tests/python/driver/tvmc/test_tvmc_common.py +++ b/tests/python/driver/tvmc/test_tvmc_common.py @@ -165,6 +165,10 @@ def test_shape_parser(): shape_string = "input:[10,10,10] input2:[20,20,20,20]" shape_dict = tvmc.common.parse_shape_string(shape_string) assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} + # Check that multiple valid input shapes with colons are parse correctly + shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" + shape_dict = tvmc.common.parse_shape_string(shape_string) + assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} # Check that alternate syntax parses correctly shape_string = "input: [10, 10, 10] input2: [20, 20, 20, 20]" shape_dict = tvmc.common.parse_shape_string(shape_string) @@ -193,6 +197,10 @@ def test_shape_parser(): tvmc.common.parse_shape_string(shape_string) # Check that input with a invalid slash raises error. shape_string = "gpu_0/data_0:5,10 /:10,10" + with pytest.raises(argparse.ArgumentTypeError): + tvmc.common.parse_shape_string(shape_string) + # Check that input with a invalid colon raises error. 
+ shape_string = "gpu_0/data_0:5,10 :test:10,10" with pytest.raises(argparse.ArgumentTypeError): tvmc.common.parse_shape_string(shape_string) # Check that input with a invalid slash raises error. diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 91d3911da5308..1cf6ffff762ce 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1840,7 +1840,16 @@ def verify_reduce_func(func, data, axis, keepdims): model = helper.make_model(graph, producer_name="reduce_test") - verify_with_ort_with_inputs(model, [data], [outshape], opset=11, target=target, dev=dev) + verify_with_ort_with_inputs( + model, + [data], + [outshape], + opset=11, + target=target, + dev=dev, + rtol=1e-4, + atol=1e-4, + ) funcs = [ "ReduceMax", @@ -1998,8 +2007,12 @@ def verify_binary_ops(op, x, y, out_type="float32"): verify_binary_ops("Sum", x, z) verify_binary_ops("Greater", x, y, "bool") verify_binary_ops("Greater", x, z, "bool") + verify_binary_ops("GreaterOrEqual", x, y, "bool") + verify_binary_ops("GreaterOrEqual", x, z, "bool") verify_binary_ops("Less", x, y, "bool") verify_binary_ops("Less", x, z, "bool") + verify_binary_ops("LessOrEqual", x, y, "bool") + verify_binary_ops("LessOrEqual", x, z, "bool") verify_binary_ops("Equal", x, y, "bool") verify_binary_ops("Equal", x, z, "bool") @@ -3457,6 +3470,49 @@ def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad=" ) +def verify_global_lppool(x_shape, p, out_shape, target, dev): + pool_node = helper.make_node( + "GlobalLpPool", + inputs=["x"], + outputs=["y"], + p=p, + ) + + graph = helper.make_graph( + [pool_node], + "global_lppool_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="global_lppool_test") + verify_with_ort( + model, [x_shape], 
out_shape, use_vm=True, convert_to_static=True, target=target, dev=dev + ) + + +@tvm.testing.parametrize_targets +def test_global_lppool(target, dev): + + # LpPool1D + verify_global_lppool(x_shape=[1, 15, 16], p=2, out_shape=[1, 15, 1], target=target, dev=dev) + + # LpPool2D + verify_global_lppool( + x_shape=[1, 15, 32, 32], p=2, out_shape=[1, 15, 1, 1], target=target, dev=dev + ) + + # LpPool2D + verify_global_lppool( + x_shape=[1, 15, 32, 32], p=3, out_shape=[1, 15, 1, 1], target=target, dev=dev + ) + + # LpPool3D + verify_global_lppool( + x_shape=[1, 15, 3, 32, 32], p=2, out_shape=[1, 15, 1, 1, 1], target=target, dev=dev + ) + + def verify_rnn( seq_length, batch_size, @@ -4858,10 +4914,6 @@ def verify_eyelike(indata): "test_cast_FLOAT_to_BFLOAT16", "test_cast_FLOAT_to_STRING", "test_cast_STRING_to_FLOAT", - "test_compress_0", - "test_compress_1", - "test_compress_default_axis", - "test_compress_negative_axis", "test_convtranspose_dilations", "test_convtranspose_output_shape", "test_cumsum_1d", @@ -4877,18 +4929,7 @@ def verify_eyelike(indata): "test_dropout_default_mask", "test_dropout_default_mask_ratio", "test_dropout_default_ratio", - "test_greater_equal", - "test_greater_equal_bcast", "test_if_seq", - "test_less_equal", - "test_less_equal_bcast", - "test_logsoftmax_axis_0_expanded", - "test_logsoftmax_axis_1_expanded", - "test_logsoftmax_axis_2_expanded", - "test_logsoftmax_default_axis_expanded", - "test_logsoftmax_example_1_expanded", - "test_logsoftmax_large_number_expanded", - "test_logsoftmax_negative_axis_expanded", "test_loop11", "test_loop13_seq", "test_matmulinteger", @@ -4944,58 +4985,14 @@ def verify_eyelike(indata): "test_round", "test_scan9_sum", "test_scan_sum", - # With reduce_sum supported fully, these expanded tests should pass - "test_sce_NCd1_mean_weight_negative_ii_expanded", - "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", - "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", - 
"test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", - "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", - "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", - "test_sce_NCd1d2d3d4d5_mean_weight_expanded", - "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", - "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", - "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", - "test_sce_mean_3d_expanded", - "test_sce_mean_3d_log_prob_expanded", - "test_sce_mean_expanded", - "test_sce_mean_log_prob_expanded", - "test_sce_mean_no_weight_ii_3d_expanded", - "test_sce_mean_no_weight_ii_3d_log_prob_expanded", - "test_sce_mean_no_weight_ii_4d_expanded", - "test_sce_mean_no_weight_ii_4d_log_prob_expanded", - "test_sce_mean_no_weight_ii_expanded", - "test_sce_mean_no_weight_ii_log_prob_expanded", - "test_sce_mean_weight_expanded", - "test_sce_mean_weight_ii_3d_expanded", - "test_sce_mean_weight_ii_3d_log_prob_expanded", - "test_sce_mean_weight_ii_4d_expanded", - "test_sce_mean_weight_ii_4d_log_prob_expanded", - "test_sce_mean_weight_ii_expanded", - "test_sce_mean_weight_ii_log_prob_expanded", - "test_sce_mean_weight_log_prob_expanded", - "test_sce_none_expanded", - "test_sce_none_log_prob_expanded", - "test_sce_none_weights_expanded", - "test_sce_none_weights_log_prob_expanded", - "test_sce_sum_expanded", - "test_sce_sum_log_prob_expanded", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_simple_rnn_defaults", "test_simple_rnn_with_initial_bias", - "test_softmax_axis_0_expanded", - "test_softmax_axis_1_expanded", - "test_softmax_axis_2_expanded", - "test_softmax_default_axis_expanded", - "test_softmax_example_expanded", - "test_softmax_large_number_expanded", - "test_softmax_negative_axis_expanded", "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", "test_split_zero_size_splits", - "test_squeeze", - "test_squeeze_negative_axes", 
"test_strnormalizer_export_monday_casesensintive_lower", "test_strnormalizer_export_monday_casesensintive_nochangecase", "test_strnormalizer_export_monday_casesensintive_upper", @@ -5015,16 +5012,13 @@ def verify_eyelike(indata): "test_training_dropout_mask", "test_training_dropout_zero_ratio", "test_training_dropout_zero_ratio_mask", - "test_unique_sorted_with_axis", - "test_unique_sorted_with_axis_3d", - "test_unique_sorted_with_negative_axis", - "test_unsqueeze_axis_0", - "test_unsqueeze_axis_1", - "test_unsqueeze_axis_2", - "test_unsqueeze_negative_axes", + # These unsqueeze tests work, but take 2+ hrs to run "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", "test_unsqueeze_unsorted_axes", + "test_unique_sorted_with_axis", + "test_unique_sorted_with_axis_3d", + "test_unique_sorted_with_negative_axis", "test_upsample_nearest", ] @@ -5873,3 +5867,4 @@ def repeat(N, D): test_random_uniform() test_convinteger() test_batch_matmul() + test_global_lppool() diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 7042450400255..9f145b75a4054 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -40,9 +40,15 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def get_tvm_runtime(script_module, input_name, ishape): +def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False): input_shapes = [(input_name, ishape)] - mod, params = relay.frontend.from_pytorch(script_module, input_shapes) + mod, params = relay.frontend.from_pytorch( + script_module, input_shapes, keep_quantized_weight=keep_quantized_weight + ) + + if keep_quantized_weight: + for p in params.values(): + assert p.dtype in ["int8", "int32"] with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda @@ -92,6 +98,20 @@ def fuse_model(self): fuse_modules(self.conv, indices, inplace=True) 
+class ConvTranspose(nn.Module): + def __init__(self): + super().__init__() + layers = [nn.ConvTranspose2d(3, 32, 3, bias=True)] + self.conv = nn.Sequential(*layers) + self.quant_wrap = QuantWrapper(self.conv) + + def forward(self, x): + return self.quant_wrap(x) + + def fuse_model(self): + pass + + class Linear(nn.Module): def __init__(self, with_relu=False): super().__init__() @@ -270,6 +290,7 @@ def test_quantized_modules(): ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel), ("linear" + postfix, (16, 16), Linear(), per_channel), ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel), + ("conv_transpose", imagenet_ishape, ConvTranspose(), False), ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False), ("hswish", imagenet_ishape, Hswish(add_stub=True), False), ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False), @@ -281,7 +302,15 @@ def test_quantized_modules(): raw_module.eval() inp = torch.rand(ishape) - quantize_model(raw_module, inp, per_channel=per_channel) + # quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0. 
+ if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"): + prev_engine = torch.backends.quantized.engine + torch.backends.quantized.engine = "qnnpack" + quantize_model(raw_module, inp, per_channel=per_channel) + torch.backends.quantized.engine = prev_engine + else: + quantize_model(raw_module, inp, per_channel=per_channel) + script_module = torch.jit.trace(raw_module, inp).eval() with torch.no_grad(): @@ -308,6 +337,7 @@ def test_quantized_modules(): conv_bn_relu 0.3700896 0.010921672 0.7489366477964451 linear 0.15987062 0.009231662 0.794921875 linear_relu 0.14180502 0.0053220326 0.8828125 + conv_transpose 0.0033792555 4.4658788e-07 0.9998678439971806 conv_bn, per_channel 0.01654929 2.9486866e-06 0.9998218235127019 conv_bn_relu, per_channel 0.009089053 1.4926576e-06 0.9998357732732732 linear, per_channel 0.0 0.0 1.0 @@ -609,3 +639,36 @@ def test_qnn_mergecomposite(): input_name = "image" run_qnn_mergecomposite(script_module, input_name, inp.shape) + + +def test_keep_quantized_weight(): + qmodules = [] + + for per_channel in [False, True]: + qmodules += [ + ((1, 3, 224, 224), ConvBn(), per_channel), + ((16, 16), Linear(), per_channel), + ] + + for (ishape, raw_module, per_channel) in qmodules: + raw_module.eval() + inp = torch.rand(ishape) + + quantize_model(raw_module, inp, per_channel=per_channel) + script_module = torch.jit.trace(raw_module, inp).eval() + + input_name = "input" + + runtime = get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False) + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + tvm_result = runtime.get_output(0).numpy() + + runtime_int8_weight = get_tvm_runtime( + script_module, input_name, ishape, keep_quantized_weight=True + ) + runtime_int8_weight.set_input(input_name, inp.numpy().copy()) + runtime_int8_weight.run() + tvm_result_int8_weight = runtime_int8_weight.get_output(0).numpy() + + tvm.testing.assert_allclose(tvm_result, tvm_result_int8_weight) diff --git 
a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index c27469edf1d7e..9238acd5f049b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1963,8 +1963,9 @@ def _gen_rand_inputs(num_boxes): boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5 boxes[:, 2] += boxes[:, 0] boxes[:, 3] += boxes[:, 1] - scores = torch.from_numpy(np.random.uniform(-1, 1, size=(num_boxes,)).astype(np.float32)) - return boxes, scores + scores = np.linspace(0, 1, num=num_boxes).astype("float32") + np.random.shuffle(scores) + return boxes, torch.from_numpy(scores) targets = ["llvm", "cuda"] diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index f2941030f0abe..c073681dcbf53 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1263,6 +1263,26 @@ def _test_transpose_conv( def test_forward_transpose_conv(): for quantized in [True, False]: for fp16_quantized in [True, False]: + # odd size input, padding VALID + _test_transpose_conv( + [1, 5, 6, 16], + [2, 2, 16, 16], + [1, 10, 12, 16], + [2, 2], + "VALID", + quantized, + fp16_quantized, + ) + # odd size input, padding SAME + _test_transpose_conv( + [1, 5, 6, 16], + [2, 2, 16, 16], + [1, 10, 12, 16], + [2, 2], + "SAME", + quantized, + fp16_quantized, + ) # kernel 3x3, padding VALID _test_transpose_conv( [4, 32, 32, 16], @@ -3266,6 +3286,7 @@ def _test_softmax(data): def test_forward_softmax(): """Softmax""" _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + _test_softmax(np.arange(6.0, dtype=np.float32).reshape((1, 2, 3))) ###################################################################### diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 490257ac66da1..746f595a4422f 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ 
b/tests/python/relay/aot/aot_test_utils.py @@ -26,7 +26,7 @@ import shutil import subprocess import tarfile -from typing import NamedTuple, Union, Optional, List, Dict +from typing import Any, NamedTuple, Union, Optional, List, Dict import pytest import numpy as np @@ -56,17 +56,53 @@ class AOTTestModel(NamedTuple): Dict of input names to value arrays outputs: List[np.array] Ordered list of output value arrays + output_tolerance: Optional[Union[int, float]] + Allowed tolerance of the output name: str Name to use for this model params: Optional[Dict[str, np.array]] Dict of parameter names to value arrays + extra_memory_in_bytes: int + Extra memory to allocate after planned memory """ module: tvm.IRModule inputs: Dict[str, np.array] outputs: List[np.array] + output_tolerance: Optional[Union[int, float]] = None name: str = "default" params: Optional[Dict[str, np.array]] = None + extra_memory_in_bytes: int = 0 + + +class AOTCompiledTestModel(NamedTuple): + """A compiled AOTTestModel with associated module + + Parameters + ---------- + model: AOTTestModel + Input model to be compiled + module: tvm.runtime.Module + The compiled Module for the associated AOTTestModel + """ + + model: AOTTestModel + executor_factory: tvm.relay.backend.executor_factory.AOTExecutorFactoryModule + + +class AOTDataLinkage(NamedTuple): + """A compiled AOTTestModel with associated module + + Parameters + ---------- + section: str + Named section to place data into + alignment: int + Section alignment + """ + + section: str + alignment: int class AOTTestRunner(NamedTuple): @@ -80,14 +116,17 @@ class AOTTestRunner(NamedTuple): Code to prepend to the main function includes: List[str] Additional includes required to run the AOT test runner - parameters: Map[str, str] + parameters: Dict[str, str] Additional parameters to pass to the make command + pass_config: Dict[str, Any] + Additional pass configuration when building the model """ makefile: str = "default" prologue: str = "" includes: List[str] = 
[] parameters: Dict[str, str] = {} + pass_config: Dict[str, Any] = {} AOT_DEFAULT_RUNNER = AOTTestRunner() @@ -225,11 +264,20 @@ def subprocess_log_output(cmd, cwd, logfile): return proc.wait() -def emit_main_prologue(main_file, custom_prologue, workspace_bytes): +# TODO: Move to linker script with list of symbols rather than coding into source +def emit_data_linkage(output_file, data_linkage): + if data_linkage is not None: + output_file.write( + f'__attribute__((section("{data_linkage.section}"), aligned({data_linkage.alignment}))) ' + ) + + +def emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage): # Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment. main_file.write( f"#define WORKSPACE_SIZE ({workspace_bytes} + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n" ) + emit_data_linkage(main_file, data_linkage) main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n") main_file.write("tvm_workspace_t app_workspace;\n") main_file.write( @@ -242,9 +290,14 @@ def emit_main_prologue(main_file, custom_prologue, workspace_bytes): return StackMemoryManager_Free(&app_workspace,ptr); } -void TVMPlatformAbort(tvm_crt_error_t code) { } +void TVMPlatformAbort(tvm_crt_error_t code) { exit(-1); } -void TVMLogf(const char* msg, ...) { } +void TVMLogf(const char* msg, ...) 
{ + va_list args; + va_start(args, msg); + vfprintf(stdout, msg, args); + va_end(args); +} TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {} int main(){\n @@ -360,23 +413,30 @@ def fake_tensor(source, source_index, packed_index): main_file.write("\n") -def emit_main_compare(main_file, output_list, mod_name): +def emit_main_compare(main_file, output_list, output_tolerance, mod_name): num_outputs = len(output_list) actual_data_name = mangle_name(mod_name, "output_data") expected_data_name = mangle_name(mod_name, "expected_output_data") for i in range(0, num_outputs): is_float_dtype = output_list[i].dtype == "float32" - main_file.write(f"for (int i = 0; i<{actual_data_name}{i}_len; i++){{\n") + + comparison_function = "abs" + tolerance = output_tolerance or 0 if is_float_dtype: - main_file.write( - f'if (fabs({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > 0.001f){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' - ) - else: - main_file.write( - f'if ({actual_data_name}{i}[i]!={expected_data_name}{i}[i]){{\n\tprintf("{AOT_FAILURE_TOKEN}\\n");\n\treturn -1;}}\n' - ) - main_file.write("}\n") + comparison_function = "fabs" + tolerance = output_tolerance or 0.001 + + main_file.write( + f""" + for (int i = 0; i<{actual_data_name}{i}_len; i++) {{ + if ({comparison_function}({actual_data_name}{i}[i]-{expected_data_name}{i}[i]) > {tolerance}) {{ + printf("{AOT_FAILURE_TOKEN}\\n"); + return -1; + }} + }} + """ + ) def emit_main_init_memory_manager(main_file): @@ -392,6 +452,8 @@ def emit_main_epilogue(main_file): def emit_main_common_includes(main_file, custom_includes): main_file.write("#include \n") + main_file.write("#include \n") + main_file.write("#include \n") main_file.write("#include \n") main_file.write('#include "tvm/runtime/c_runtime_api.h"\n') main_file.write('#include "tvm/runtime/crt/stack_allocator.h"\n') @@ -404,7 +466,14 @@ def emit_main_micro_include(main_file, mod_name): def create_main( - test_name, 
models, output_path, custom_includes, custom_prologue, interface_api, workspace_bytes + test_name, + models, + output_path, + custom_includes, + custom_prologue, + data_linkage, + interface_api, + workspace_bytes, ): file_path = pathlib.Path(f"{output_path}/" + test_name).resolve() # create header file @@ -418,7 +487,7 @@ def create_main( for model in models: emit_main_data(main_file, model.inputs, model.outputs, model.name) - emit_main_prologue(main_file, custom_prologue, workspace_bytes) + emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage) emit_main_init_memory_manager(main_file) if interface_api == "c": @@ -432,11 +501,11 @@ def create_main( emit_main_packed_call(main_file, model.inputs, model.outputs, model.name) for model in models: - emit_main_compare(main_file, model.outputs, model.name) + emit_main_compare(main_file, model.outputs, model.output_tolerance, model.name) emit_main_epilogue(main_file) -def create_header_file(tensor_name, npy_data, output_path): +def create_header_file(tensor_name, npy_data, output_path, data_linkage): """ This method generates a header file containing the data contained in the numpy array provided. It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application. 
@@ -450,6 +519,8 @@ def create_header_file(tensor_name, npy_data, output_path): header_file.write("#include \n") header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n") + emit_data_linkage(header_file, data_linkage) + if npy_data.dtype == "int8": header_file.write(f"int8_t {tensor_name}[] =") elif npy_data.dtype == "int32": @@ -471,37 +542,67 @@ def extract_main_workspace_size_bytes(extract_dir): return metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"] -def compile_and_run( +def compile_models( models: Union[List[AOTTestModel], AOTTestModel], - runner: AOTTestRunner, - interface_api, - use_unpacked_api, - debug_calculated_workspaces=False, - workspace_byte_alignment=8, - enable_op_fusion=True, -): + interface_api: str, + use_unpacked_api: bool, + workspace_byte_alignment: int = 8, + enable_op_fusion: bool = True, + pass_config: Dict[str, Any] = None, +) -> List[AOTCompiledTestModel]: """ - This method verifies the generated source + This method generates runtime.Modules for the tests """ - base_target = "c -runtime=c --link-params --executor=aot" - extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" - target = f"{base_target} {extra_target}" - cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " - if not isinstance(models, list): models = [models] - # The calculated workspaces will not account for stack allocator tags used for debugging - if debug_calculated_workspaces: - cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " + base_target = "c -runtime=c --link-params --executor=aot" + extra_target = f"--workspace-byte-alignment={workspace_byte_alignment} --interface-api={interface_api} --unpacked-api={int(use_unpacked_api)}" + target = f"{base_target} {extra_target}" config = {"tir.disable_vectorize": True} + if pass_config: + config = {**config, **pass_config} if not enable_op_fusion: 
config["relay.FuseOps.max_depth"] = 1 + compiled_mods = list() + for model in models: + with tvm.transform.PassContext(opt_level=3, config=config): + executor_factory = tvm.relay.build( + model.module, + target, + target_host=target, + params=model.params, + mod_name=model.name, + ) + compiled_mods.append( + AOTCompiledTestModel(model=model, executor_factory=executor_factory) + ) + return compiled_mods + + +def run_and_check( + models: List[AOTCompiledTestModel], + runner: AOTTestRunner, + interface_api: str, + debug_calculated_workspaces=False, + workspace_byte_alignment=8, + data_linkage: AOTDataLinkage = None, +): + """ + This method uses the original test data and compiled runtime.Modules + to run in the test runner to verify the results. + """ + tmp_path = utils.tempdir() tmp_dir = tmp_path.temp_dir + cflags = f"-DTVM_RUNTIME_ALLOC_ALIGNMENT_BYTES={workspace_byte_alignment} " + # The calculated workspaces will not account for stack allocator tags used for debugging + if debug_calculated_workspaces: + cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK " + base_path = os.path.join(tmp_dir, "test") build_path = os.path.join(base_path, "build") os.makedirs(build_path, exist_ok=True) @@ -515,21 +616,14 @@ def compile_and_run( ) workspace_bytes = 0 - for model in models: - with tvm.transform.PassContext(opt_level=3, config=config): - lib = tvm.relay.build( - model.module, - target, - target_host=target, - params=model.params, - mod_name=model.name, - ) - + for compiled_model in models: + model = compiled_model.model tar_file = os.path.join(base_path, f"{model.name}.tar") - export_model_library_format(lib, tar_file) + export_model_library_format(compiled_model.executor_factory, tar_file) t = tarfile.open(tar_file) t.extractall(base_path) + workspace_bytes += model.extra_memory_in_bytes workspace_bytes += extract_main_workspace_size_bytes(base_path) for key in model.inputs: @@ -538,6 +632,7 @@ def compile_and_run( f'{mangle_name(model.name, 
"input_data")}_{sanitized_tensor_name}', model.inputs[key], include_path, + data_linkage, ) for i in range(len(model.outputs)): @@ -545,19 +640,22 @@ def compile_and_run( (f'{mangle_name(model.name,"output_data")}{i}'), np.zeros(model.outputs[i].shape, model.outputs[i].dtype), include_path, + data_linkage, ) create_header_file( (f'{mangle_name(model.name, "expected_output_data")}{i}'), model.outputs[i], include_path, + data_linkage, ) create_main( "test.c", - models, + [compiled_model.model for compiled_model in models], build_path, runner.includes, runner.prologue, + data_linkage, interface_api, workspace_bytes, ) @@ -592,6 +690,35 @@ def compile_and_run( assert AOT_SUCCESS_TOKEN in run_log.read() +def compile_and_run( + models: Union[List[AOTTestModel], AOTTestModel], + runner: AOTTestRunner, + interface_api: str, + use_unpacked_api: bool, + debug_calculated_workspaces: bool = False, + workspace_byte_alignment: int = 8, + enable_op_fusion: bool = True, + data_linkage: AOTDataLinkage = None, +): + """This is a wrapper API to compile and run models as test for AoT""" + compiled_test_mods = compile_models( + models=models, + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + workspace_byte_alignment=workspace_byte_alignment, + enable_op_fusion=enable_op_fusion, + pass_config=runner.pass_config, + ) + run_and_check( + models=compiled_test_mods, + runner=runner, + interface_api=interface_api, + debug_calculated_workspaces=debug_calculated_workspaces, + workspace_byte_alignment=workspace_byte_alignment, + data_linkage=data_linkage, + ) + + def generate_ref_data(mod, input_data, params=None, target="llvm"): """Generate reference data through executing the relay module""" compile_engine.get().clear() diff --git a/tests/python/relay/aot/corstone300.ld b/tests/python/relay/aot/corstone300.ld index 4a6b22480d9f8..9534b869f6e6c 100644 --- a/tests/python/relay/aot/corstone300.ld +++ b/tests/python/relay/aot/corstone300.ld @@ -257,6 +257,14 @@ SECTIONS 
__bss_end__ = .; } > DTCM AT > DTCM + .ddr : + { + . = ALIGN(4); + . = ALIGN(16); + *(ethosu_scratch) + . = ALIGN (16); + } > DDR + .data_sram : { . = ALIGN(16); diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk index 3a946f2cd8768..8d03ccc5b5f40 100644 --- a/tests/python/relay/aot/corstone300.mk +++ b/tests/python/relay/aot/corstone300.mk @@ -28,9 +28,11 @@ endif ARM_CPU=ARMCM55 DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core ETHOSU_PATH=/opt/arm/ethosu +DRIVER_PATH=${ETHOSU_PATH}/core_driver CMSIS_PATH=${ETHOSU_PATH}/cmsis PLATFORM_PATH=${ETHOSU_PATH}/core_platform/targets/corstone-300 PKG_COMPILE_OPTS = -g -Wall -O2 -Wno-incompatible-pointer-types -Wno-format -mcpu=cortex-m55 -mthumb -mfloat-abi=hard -std=gnu99 +CMAKE = /opt/arm/cmake/bin/cmake CC = arm-none-eabi-gcc AR = arm-none-eabi-ar RANLIB = arm-none-eabi-ranlib @@ -40,11 +42,15 @@ PKG_CFLAGS = ${PKG_COMPILE_OPTS} \ -I$(build_dir)/../include \ -I$(CODEGEN_ROOT)/host/include \ -I${PLATFORM_PATH} \ + -I${DRIVER_PATH}/include \ -I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \ -I${CMSIS_PATH}/CMSIS/Core/Include \ -I${CMSIS_PATH}/CMSIS/NN/Include \ -I${CMSIS_PATH}/CMSIS/DSP/Include \ - -isystem$(STANDALONE_CRT_DIR)/include \ + -isystem$(STANDALONE_CRT_DIR)/include +DRIVER_CMAKE_FLAGS = -DCMAKE_TOOLCHAIN_FILE=$(ETHOSU_TEST_ROOT)/arm-none-eabi-gcc.cmake \ + -DETHOSU_LOG_SEVERITY=debug \ + -DCMAKE_SYSTEM_PROCESSOR=cortex-m55 PKG_LDFLAGS = -lm -specs=nosys.specs -static -T ${AOT_TEST_ROOT}/corstone300.ld @@ -61,6 +67,11 @@ CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c CMSIS_NN_SRCS = $(shell find ${CMSIS_PATH}/CMSIS/NN/Source/*/*.c) UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) +ifdef ETHOSU_TEST_ROOT +ETHOSU_ARCHIVE=${build_dir}/ethosu_core_driver/libethosu_core_driver.a +ETHOSU_INCLUDE=-I$(ETHOSU_TEST_ROOT) +endif + aot_test_runner: $(build_dir)/aot_test_runner $(build_dir)/stack_allocator.o: $(TVM_ROOT)/src/runtime/crt/memory/stack_allocator.c 
@@ -94,9 +105,14 @@ ${build_dir}/libuart.a: $(UART_SRCS) $(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o $(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a) -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a +${build_dir}/ethosu_core_driver/libethosu_core_driver.a: + $(QUIET)mkdir -p $(@D) + $(QUIET)cd $(DRIVER_PATH) && $(CMAKE) -B $(abspath $(build_dir)/ethosu_core_driver) $(DRIVER_CMAKE_FLAGS) + $(QUIET)cd $(abspath $(build_dir)/ethosu_core_driver) && $(MAKE) + +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libcmsis_nn.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(ETHOSU_ARCHIVE) $(QUIET)mkdir -p $(@D) - $(QUIET)$(CC) $(PKG_CFLAGS) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) + $(QUIET)$(CC) $(PKG_CFLAGS) $(ETHOSU_INCLUDE) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS) clean: $(QUIET)rm -rf $(build_dir)/crt @@ -109,6 +125,7 @@ run: $(build_dir)/aot_test_runner -C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ -C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ + -C ethosu.extra_args="--fast" \ -C ethosu.num_macs=$(NPU_VARIANT) $(build_dir)/aot_test_runner .SUFFIXES: diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index e117302d0ed84..d90c4217c4c3d 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -26,12 +26,14 @@ from tvm.ir.module import IRModule from tvm.relay import 
testing, transform from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, generate_ref_data, convert_to_relay, compile_and_run, + compile_models, parametrize_aot_options, ) @@ -297,13 +299,22 @@ def test_mobilenet(debug_calculated_workspaces, workspace_byte_alignment): interface_api = "c" test_runner = AOT_DEFAULT_RUNNER + # TODO(@Mousius) - Enable memory planning to take into account debug information + debugging_memory_overhead = 1024 * 1024 + mod, params = testing.mobilenet.get_workload(batch_size=1) data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} output_list = generate_ref_data(mod, inputs, params) compile_and_run( - AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + AOTTestModel( + module=mod, + inputs=inputs, + outputs=output_list, + params=params, + extra_memory_in_bytes=debugging_memory_overhead, + ), test_runner, interface_api, use_unpacked_api, @@ -312,8 +323,58 @@ def test_mobilenet(debug_calculated_workspaces, workspace_byte_alignment): ) -def test_byoc_microtvm(): - """This is a simple test case to check BYOC capabilities of AOT""" +@pytest.mark.parametrize("merge_compiler_regions", [False, True]) +def test_byoc_microtvm(merge_compiler_regions): + """This is a simple test to check BYOC capabilities of AOT - with and without merging compiler regions to test for https://github.com/apache/tvm/issues/9036""" + use_unpacked_api = False + interface_api = "packed" + test_runner = AOT_DEFAULT_RUNNER + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + + # z0 = x + w0 + x_ = compiler_begin(x, "ccompiler") + w0_ = compiler_begin(w0, "ccompiler") + z0_ = relay.add(x_, w0_) + z0 = compiler_end(z0_, "ccompiler") + + # z1 = z0 + w1 + z0__ = 
compiler_begin(z0, "ccompiler") + w1_ = compiler_begin(w1, "ccompiler") + z1_ = relay.add(z0__, w1_) + z1 = compiler_end(z1_, "ccompiler") + + # z2 = z0 + z1 + z2 = relay.add(z0, z1) + + f = relay.Function([x, w0, w1], z2) + mod = tvm.IRModule() + mod["main"] = f + + if merge_compiler_regions: + mod = transform.MergeCompilerRegions()(mod) + + mod = transform.PartitionGraph("mod_name")(mod) + mod = transform.InferType()(mod) + + x_data = [("x", np.random.rand(10, 10).astype("float32"))] + w_data = [("w{}".format(i), np.random.rand(10, 10).astype("float32")) for i in range(2)] + + map_inputs = OrderedDict(x_data + w_data) + output_list = generate_ref_data(mod, map_inputs) + compile_and_run( + AOTTestModel(name="my_mod", module=mod, inputs=map_inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@pytest.mark.parametrize("merge_compiler_regions", [False, True]) +def test_byoc_microtvm_multiple_subgraphs(merge_compiler_regions): + """This is a test case to check BYOC capabilities of AOT with multiple sub graphs""" use_unpacked_api = False interface_api = "packed" test_runner = AOT_DEFAULT_RUNNER @@ -347,6 +408,9 @@ def test_byoc_microtvm(): ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) + if merge_compiler_regions: + mod = transform.MergeCompilerRegions()(mod) + mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) mod = tvm.relay.transform.InferType()(mod) @@ -589,5 +653,45 @@ def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_work ) +def test_aot_codegen_backend_alloc_workspace_calls(): + """This test checks whether AoT lowering creates TVMBackendAllocWorkspace calls""" + + # The %data and %weight shapes in the following primitive Relay should create + # small tensors that would get lowered to stack allocations in the CPU PrimFuncs. 
+ # However, the AoT executor codegen should retain them as TVMBAW calls + relay_mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), float32], src_layout="OIHW", dst_layout="OIHW4i4o") -> Tensor[(1, 4, 4, 4), float32] { + %0 = fn (%p02: Tensor[(1, 4, 4, 4), float32], Primitive=1, hash="9332b3872fb5292c", src_layout="NCHW", dst_layout="NCHW4c") -> Tensor[(1, 1, 4, 4, 4), float32] { + layout_transform(%p02, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 4, 4, 4), float32] */ + }; + %1 = fn (%p03: Tensor[(4, 4, 3, 3), float32], Primitive=1, hash="9f0b2b8a24a4dab3", src_layout="OIHW", dst_layout="OIHW4i4o") -> Tensor[(1, 1, 3, 3, 4, 4), float32] { + layout_transform(%p03, src_layout="OIHW", dst_layout="OIHW4i4o") /* ty=Tensor[(1, 1, 3, 3, 4, 4), float32] */ + }; + %2 = %0(%data) /* ty=Tensor[(1, 1, 4, 4, 4), float32] */; + %3 = %1(%weight) /* ty=Tensor[(1, 1, 3, 3, 4, 4), float32] */; + %4 = fn (%p01: Tensor[(1, 1, 4, 4, 4), float32], %p1: Tensor[(1, 1, 3, 3, 4, 4), float32], out_layout="NCHW4c", kernel_layout="OIHW4i4o", Primitive=1, data_layout="NCHW4c") -> Tensor[(1, 1, 4, 4, 4), float32] { + nn.contrib_conv2d_NCHWc(%p01, %p1, padding=[1, 1, 1, 1], channels=4, kernel_size=[3, 3], data_layout="NCHW4c", kernel_layout="OIHW4i4o", out_layout="NCHW4c") /* ty=Tensor[(1, 1, 4, 4, 4), float32] */ + }; + %5 = %4(%2, %3) /* ty=Tensor[(1, 1, 4, 4, 4), float32] */; + %6 = fn (%p0: Tensor[(1, 1, 4, 4, 4), float32], Primitive=1, src_layout="NCHW4c", dst_layout="NCHW") -> Tensor[(1, 4, 4, 4), float32] { + layout_transform(%p0, src_layout="NCHW4c", dst_layout="NCHW") /* ty=Tensor[(1, 4, 4, 4), float32] */ + }; + %6(%5) /* ty=Tensor[(1, 4, 4, 4), float32] */ + } + """ + ) + compiled_test_mods = compile_models( + models=AOTTestModel(module=relay_mod, inputs=None, outputs=None), + interface_api="c", + use_unpacked_api=True, + ) + source = 
compiled_test_mods[0].executor_factory.lib.imported_modules[0].get_source() + # There should be three allocates created for three primitive relay function + # calls in the main for the above relay snippet. + assert source.count("TVMBackendAllocWorkspace") == 3 + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index d2ad5a47f15b5..22583eda4a408 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -19,11 +19,10 @@ import numpy as np import pytest import tvm -from tvm import te -from tvm import relay +import tvm.testing +from tvm import relay, te from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type -import tvm.testing def verify_func(func, data, ref_res, target_device=tvm.testing.enabled_targets()): @@ -93,6 +92,51 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +def test_squeeze(): + def verify_squeeze(shape, dtype, axis): + x = relay.var("x", relay.TensorType(shape, dtype)) + assert axis is not None + np_axis = tuple(axis) + axis = relay.var("axis", relay.TensorType([len(axis)], "int64")) + squeeze = relay.squeeze(x, axis=axis) + func = relay.Function([x, axis], squeeze) + x_data = np.random.random_sample(shape).astype(dtype) + ref_res = np.squeeze(x_data, axis=np_axis) + verify_func(func, [x_data, np.array(np_axis).astype("int64")], ref_res) + + verify_squeeze((1, 3, 1), "float32", [0]) + verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2]) + + +@tvm.testing.uses_gpu +def test_dyn_expand_dims(): + def verify_expand_dims( + dshape, dtype, oshape, axis, num_newaxis, target_device=tvm.testing.enabled_targets() + ): + # Use 1 to avoid issues with invalid buffer sizes + x = relay.Var("x", relay.TensorType(dshape, dtype)) + y = relay.var("axis", shape=[], dtype="int64") 
+ z = relay.expand_dims(x, axis=y, num_newaxis=num_newaxis) + func = relay.Function([x, y], z) + + data_np = np.random.uniform(size=dshape).astype(dtype) + axis_np = np.array(axis).astype("int64") + ref_res = data_np.reshape(oshape) + verify_func(func, [data_np, axis_np], ref_res, target_device=target_device) + + for dtype in ["float16", "float32"]: + verify_expand_dims((2, 2), dtype, (2, 2, 1), 2, 1) + verify_expand_dims((2, 2), dtype, (2, 1, 2), 1, 1) + verify_expand_dims((2, 2), dtype, (1, 2, 2), 0, 1) + + # TODO (AndrewZhaoLuo): investigate why runtimes in non-llvm are extremely slow + # for multiple new axis + llvm_target_only = [x for x in tvm.testing.enabled_targets() if "llvm" in x] + verify_expand_dims((2, 2), dtype, (2, 2, 1, 1), 2, 2, target_device=llvm_target_only) + verify_expand_dims((2, 2), dtype, (2, 1, 1, 1, 2), 1, 3, target_device=llvm_target_only) + verify_expand_dims((2, 2), dtype, (1, 1, 1, 1, 2, 2), 0, 4, target_device=llvm_target_only) + + @tvm.testing.uses_gpu def test_dyn_tile(): def verify_tile(dshape, reps): diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py new file mode 100644 index 0000000000000..58e559eb96809 --- /dev/null +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for annotations.""" +import tvm +from tvm import relay +import pytest + + +def test_on_device_via_string(): + x = relay.Var("x") + call = relay.annotation.on_device(x, "cuda") + assert isinstance(call, relay.Call) + assert len(call.args) == 1 + assert call.args[0] == x + assert call.attrs.device_type == 2 # ie kDLCUDA + assert not call.attrs.is_fixed + + +def test_on_device_via_device(): + x = relay.Var("x") + call = relay.annotation.on_device(x, tvm.device("llvm")) + assert call.attrs.device_type == 1 # ie kDLCPU + + +def test_on_device_invalid_device(): + x = relay.Var("x") + pytest.raises(ValueError, lambda: relay.annotation.on_device(x, "bogus")) + + +def test_on_device_is_fixed(): + x = relay.Var("x") + call = relay.annotation.on_device(x, "cuda", True) + assert call.attrs.device_type == 2 + assert call.attrs.is_fixed + + +def test_function_on_device(): + x = relay.Var("x") + y = relay.Var("y") + f = relay.Function([x, y], relay.add(x, y)) + func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") + assert isinstance(func, relay.Function) + assert len(func.attrs["param_device_types"]) == 2 + assert func.attrs["param_device_types"][0] == 1 # ie kDLCPU + assert func.attrs["param_device_types"][1] == 2 # ie kDLCUDA + assert func.attrs["result_device_type"] == 2 # ie KDLCUDA + + +if __name__ == "__main__": + test_on_device_via_string() + test_on_device_via_device() + test_on_device_invalid_device() + test_on_device_is_fixed() + test_function_on_device() diff --git a/tests/python/relay/test_external_codegen.py 
b/tests/python/relay/test_external_codegen.py index ad5f2aa9d4faf..41c113684f0a4 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -213,6 +213,39 @@ def constant_updater(expr, symbol): tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") +@pytest.mark.skipif( + not tvm.get_global_func("relay.ext.dnnl", True), + reason="skip because DNNL codegen is not available", +) +@parametrize_external_json_codegen_checks +def test_extern_dnnl_padding(check_result): + dtype = "float32" + ishape = (1, 1, 99, 12) + w1shape = (54, 1, 3, 3) + data0 = relay.var("data0", shape=(ishape), dtype=dtype) + weight0 = relay.var("weight0", shape=(w1shape), dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), strides=(2, 2), padding=(1, 0, 1, 1)) + f = relay.Function([data0, weight0], out) + ref_mod = tvm.IRModule() + ref_mod["main"] = f + + data1 = relay.var("data0", shape=(ishape), dtype=dtype) + weight1 = relay.var("weight0", shape=(w1shape), dtype=dtype) + f = set_external_func_attr(f, "dnnl", "dnnl_0") + call = relay.Call(f, [data1, weight1]) + mod = tvm.IRModule.from_expr(call) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + ref_res = relay.create_executor("graph", mod=ref_mod, device=tvm.cpu()).evaluate()( + i_data, w_data + ) + check_result( + mod, {"data0": i_data, "weight0": w_data}, (1, 54, 50, 6), ref_res.numpy(), tol=1e-5 + ) + + @pytest.mark.skipif( not tvm.get_global_func("relay.ext.dnnl", True), reason="skip because DNNL codegen is not available", diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 099e127aeba99..fdbd3924ffb7f 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -23,7 +23,6 @@ from numpy import isclose from typing import Union - SEMVER = '#[version = "0.0.5"]\n' BINARY_OPS = { @@ -967,6 +966,30 @@ def 
test_func_attrs(): assert_parses_as(func.astext(), func) +def test_init_module_and_metatable(): + init_metatable = {"relay.Constant": [relay.const(np.random.rand(2, 3), dtype="float32")]} + init_module = tvm.parser.fromtext( + SEMVER + + """ + def @f(%y : Tensor[(2, 3), float32]) -> Tensor[(2, 3), float32] { + negative(%y) + } + """, + ) + mod = tvm.parser.parse( + SEMVER + + """ + def @main(%x: Tensor[(2, 3), float32]) { + add(@f(%x), meta[relay.Constant][0]) + } + """, + "from_string", + init_module, + init_metatable, + ) + roundtrip(mod) + + if __name__ == "__main__": import sys diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index f5674dbf5fb39..ca792204c835e 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -96,12 +96,14 @@ def test_conv2d(): def conv2d_direct(): dtype = "float32" - ishape = (1, 32, 14, 14) - w1shape = (32, 32, 3, 3) + ishape = (1, 1, 99, 12) + w1shape = (54, 1, 3, 3) data0 = relay.var("data", shape=ishape, dtype=dtype) weight0 = relay.var("weight", shape=w1shape, dtype=dtype) - out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1)) + out = relay.nn.conv2d( + data0, weight0, kernel_size=(3, 3), strides=(2, 2), padding=(1, 0, 1, 1) + ) func = relay.Function([data0, weight0], out) func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") @@ -118,7 +120,9 @@ def conv2d_direct(): data0 = relay.var("data", shape=ishape, dtype=dtype) weight0 = relay.var("weight", shape=w1shape, dtype=dtype) - out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1)) + out = relay.nn.conv2d( + data0, weight0, kernel_size=(3, 3), strides=(2, 2), padding=(1, 0, 1, 1) + ) main_f = relay.Function([data0, weight0], out) ref_mod = tvm.IRModule() ref_mod["main"] = main_f @@ -127,7 +131,7 @@ def conv2d_direct(): i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype) - return mod, 
ref_mod, {"data": i_data, "weight": w1_data}, (1, 32, 14, 14) + return mod, ref_mod, {"data": i_data, "weight": w1_data}, (1, 54, 50, 6) def group_conv2d(): dtype = "float32" @@ -212,6 +216,50 @@ def gen_add(): check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5) +def test_multiply(): + """Test a subgraph with a single add operator.""" + if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): + print("skip because DNNL codegen is not available") + return + + dtype = "float32" + shape = (10, 10) + + def gen_multiply(): + data0 = relay.var("data0", shape=shape, dtype=dtype) + data1 = relay.var("data1", shape=shape, dtype=dtype) + out = relay.multiply(data0, data1) + + func = relay.Function([data0, data1], out) + func = set_func_attr(func, "dnnl", "tvmgen_default_dnnl_0") + glb_var = relay.GlobalVar("tvmgen_default_dnnl_0") + mod = tvm.IRModule() + mod[glb_var] = func + mod = transform.InferType()(mod) + + data0 = relay.var("data0", shape=shape, dtype=dtype) + data1 = relay.var("data1", shape=shape, dtype=dtype) + main_f = relay.Function([data0, data1], glb_var(data0, data1)) + mod["main"] = main_f + mod = transform.InferType()(mod) + + data0 = relay.var("data0", shape=shape, dtype=dtype) + data1 = relay.var("data1", shape=shape, dtype=dtype) + out = relay.multiply(data0, data1) + main_f = relay.Function([data0, data1], out) + ref_mod = tvm.IRModule() + ref_mod["main"] = main_f + ref_mod = transform.InferType()(ref_mod) + + return mod, ref_mod + + mod, ref_mod = gen_multiply() + + data0 = np.random.uniform(0, 1, shape).astype(dtype) + data1 = np.random.uniform(0, 1, shape).astype(dtype) + check_result(mod, ref_mod, {"data0": data0, "data1": data1}, shape, tol=1e-5) + + def test_relu(): """Test a subgraph with a single ReLU operator.""" if not tvm.get_global_func("runtime.DNNLJSONRuntimeCreate", True): @@ -221,7 +269,7 @@ def test_relu(): dtype = "float32" shape = (1, 32, 14, 14) - def gen_relu(): + def gen_relu(shape): data0 = 
relay.var("data0", shape=shape, dtype=dtype) out = relay.nn.relu(data0) @@ -246,18 +294,22 @@ def gen_relu(): return mod, ref_mod - mod, ref_mod = gen_relu() + def check(shape): + mod, ref_mod = gen_relu(shape) + + data0 = np.random.uniform(-1, 1, shape).astype(dtype) + check_result( + mod, + ref_mod, + { + "data0": data0, + }, + shape, + tol=1e-5, + ) - data0 = np.random.uniform(-1, 1, shape).astype(dtype) - check_result( - mod, - ref_mod, - { - "data0": data0, - }, - (1, 32, 14, 14), - tol=1e-5, - ) + check(shape=(1, 32, 14, 14)) + check(shape=(1, 32)) def test_dense(): @@ -664,6 +716,7 @@ def test_partial_constant(): if __name__ == "__main__": test_conv2d() test_add() + test_multiply() test_relu() test_dense() test_bn() diff --git a/tests/python/relay/test_name_transforms.py b/tests/python/relay/test_name_transforms.py new file mode 100644 index 0000000000000..c4a7d6c4477c0 --- /dev/null +++ b/tests/python/relay/test_name_transforms.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License" you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from tvm import TVMError +from tvm.relay.backend.name_transforms import ( + to_c_function_style, + to_c_variable_style, + prefix_name, + prefix_generated_name, + sanitize_name, +) +import pytest + + +def test_to_c_function_style(): + assert to_c_function_style("TVM_Woof") == "TVMWoof" + assert to_c_function_style("TVM_woof") == "TVMWoof" + assert to_c_function_style("TVM_woof_woof") == "TVMWoofWoof" + assert to_c_function_style("TVMGen_woof_woof") == "TVMGenWoofWoof" + + # Incorrect prefix + with pytest.raises(TVMError, match="Function not TVM prefixed"): + to_c_function_style("Cake_Bakery") + with pytest.raises(TVMError, match="Function name is empty"): + to_c_function_style("") + + +def test_to_c_variable_style(): + assert to_c_variable_style("TVM_Woof") == "tvm_woof" + assert to_c_variable_style("TVM_woof") == "tvm_woof" + assert to_c_variable_style("TVM_woof_Woof") == "tvm_woof_woof" + + # Incorrect prefix + with pytest.raises(TVMError, match="Variable not TVM prefixed"): + to_c_variable_style("Cake_Bakery") + with pytest.raises(TVMError, match="Variable name is empty"): + to_c_variable_style("") + + +def test_prefix_name(): + assert prefix_name("Woof") == "TVM_Woof" + assert prefix_name(["Woof"]) == "TVM_Woof" + assert prefix_name(["woof"]) == "TVM_woof" + assert prefix_name(["woof", "moo"]) == "TVM_woof_moo" + + with pytest.raises(TVMError, match="Name is empty"): + prefix_name("") + with pytest.raises(TVMError, match="Name segments empty"): + prefix_name([]) + with pytest.raises(TVMError, match="Name segment is empty"): + prefix_name([""]) + + +def test_prefix_generated_name(): + assert prefix_generated_name("Woof") == "TVMGen_Woof" + assert prefix_generated_name(["Woof"]) == "TVMGen_Woof" + assert prefix_generated_name(["Woof"]) == "TVMGen_Woof" + assert prefix_generated_name(["woof"]) == "TVMGen_woof" + assert prefix_generated_name(["woof", "moo"]) == "TVMGen_woof_moo" + + with pytest.raises(TVMError, match="Name is empty"): + prefix_generated_name("") 
+ with pytest.raises(TVMError, match="Name segments empty"): + prefix_generated_name([]) + with pytest.raises(TVMError, match="Name segment is empty"): + prefix_generated_name([""]) + + +def test_sanitize_name(): + assert sanitize_name("+_+ ") == "_" + assert sanitize_name("input+") == "input_" + assert sanitize_name("input-") == "input_" + assert sanitize_name("input++") == "input_" + assert sanitize_name("woof:1") == "woof_1" + + with pytest.raises(TVMError, match="Name is empty"): + sanitize_name("") + + +def test_combined_logic(): + assert ( + to_c_function_style(prefix_name(["Device", "target", "Invoke"])) == "TVMDeviceTargetInvoke" + ) + assert to_c_function_style(prefix_generated_name(["model", "Run"])) == "TVMGenModelRun" + assert to_c_variable_style(prefix_name(["Device", "target", "t"])) == "tvm_device_target_t" + assert ( + to_c_variable_style(prefix_generated_name(["model", "Devices"])) == "tvmgen_model_devices" + ) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 44f211dd9f8ac..da2877063c45f 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1587,156 +1587,143 @@ def test_upsampling3d(): _test_upsampling3d("NDHWC", "trilinear", "align_corners") -@tvm.testing.uses_gpu -def test_conv2d_int8_intrinsics(): - def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): +@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8") +class TestConv2DInt8Intrinsics: + supported_targets = [ + "llvm -mcpu=nehalem", + "llvm -mcpu=core-avx2", + "llvm -mcpu=skylake-avx512", + "llvm -mcpu=cascadelake", + ] + + unsupported_targets = [ + "llvm -mcpu=x86-64", + ] + + data_layout, kernel_layout = tvm.testing.parameters( + ("NCHW", "OIHW"), + # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout. + # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC. 
+ # ("NHWC", "HWIO"), + ) + + input_channels, output_channels = tvm.testing.parameters( + # Sweep the input channels to check int8 robustness + # Input channels should be a multiple of 4 internally. + (1, 16), + (4, 16), + (6, 16), + # Sweep the output channels to check int8 robustness + # Output channels should be a multiple of 16 internally. + (8, 4), + (8, 16), + (8, 20), + # Check that both non-divisible oc and ic work + (17, 29), + ) + + @tvm.testing.fixture + def fast_int8_intrinsic(self, target): + if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: + return "pmaddubs" + elif "cascadelake" in target: + return "vpdpbusd" + else: + assert False, "Target should be Skylake or Cascadelake" + + @tvm.testing.fixture + def assembly( + self, + target, + dtypes, + input_channels, + output_channels, + data_layout, + kernel_layout, + ): input_dtype, weight_dtype, output_dtype = dtypes - n, h, w, ch, cw = 1, 64, 64, 3, 3 + image_size = (64, 64) + kernel_size = (3, 3) + batch_size = 1 + + h, w = image_size + if data_layout == "NCHW": - data_shape = (n, ic, h, w) - x = relay.var("x", relay.TensorType(data_shape, input_dtype)) + data_shape = (batch_size, input_channels, *image_size) elif data_layout == "NHWC": - data_shape = (n, h, w, ic) - x = relay.var("x", relay.TensorType(data_shape, input_dtype)) + data_shape = (batch_size, *image_size, input_channels) else: - raise ValueError("Not supported") + raise ValueError(f"Unsupported data layout: {data_layout}") + x = relay.var("x", relay.TensorType(data_shape, input_dtype)) if kernel_layout == "OIHW": - kernel_shape = (oc, ic, ch, cw) + kernel_shape = (output_channels, input_channels, *kernel_size) elif kernel_layout == "HWIO": - kernel_shape = (ch, cw, ic, oc) + kernel_shape = (*kernel_size, input_channels, output_channels) else: raise ValueError("Not supported") - weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype)) + y = relay.nn.conv2d( x, weight, - kernel_size=(ch, cw), - 
channels=oc, + kernel_size=kernel_size, + channels=output_channels, padding=(0, 0, 0, 1), dilation=(1, 1), data_layout=data_layout, kernel_layout=kernel_layout, out_dtype=output_dtype, ) + func = relay.Function([x, weight], y) + wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) - assembly = lib.get_source("asm") - return assembly - - def _has_fast_int8_instructions(asm, target): - if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: - return "pmaddubs" in asm - elif "cascadelake" in target: - return "vpdpbusd" in asm - else: - assert False, "Target should be Skylake or Cascadelake" - - # TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout. - # Re-enable this after adding conv2d_NCHWc_int8 support for NHWC. - - # compile conv2d for x86 (SSE3/AVX2/AVX512/VNNI capable) and test assembly contains *pmadd* instructions - targets = [ - "llvm -mcpu=nehalem", - "llvm -mcpu=core-avx2", - "llvm -mcpu=skylake-avx512", - "llvm -mcpu=cascadelake", - ] - llvm_version = tvm.target.codegen.llvm_version_major() - for target in targets: - if tvm.testing.device_enabled(target) and llvm_version >= 8: - dtypes = ("uint8", "int8", "int32") - # Sweep the input channels to check int8 robustness - # Input channels should be a multiple of 4 internally. - for ic in [1, 4, 6]: - asm = _compile( - ic=ic, - oc=16, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=dtypes, - ) - assert _has_fast_int8_instructions(asm, target) - - # for ic in [1, 4, 6]: - # asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC", - # kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Sweep the output channels to check int8 robustness - # Output channels should be a multiple of 16 internally. 
- for oc in [4, 16, 20]: - asm = _compile( - ic=8, - oc=oc, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=dtypes, - ) - assert _has_fast_int8_instructions(asm, target) - - # for oc in [4, 16, 20]: - # asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC", - # kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Check that both non-divisible oc and ic work - asm = _compile( - ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes - ) - assert _has_fast_int8_instructions(asm, target) - - # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) - - # Check that int8 x int8 goes through legalization so that fast instructions can be picked up. - for target in targets: - if tvm.testing.device_enabled(target) and llvm_version >= 8: - dtypes = ("int8", "int8", "int32") - # Check that both non-divisible oc and ic work - asm = _compile( - ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout="OIHW", dtypes=dtypes - ) - assert _has_fast_int8_instructions(asm, target) - - # asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # assert _has_fast_int8_instructions(asm, target) + return lib.get_source("asm") + + # Ensure that code uses the fast int8 instructions when available. + @tvm.testing.parametrize_targets(*supported_targets) + @pytest.mark.parametrize( + "dtypes", + [ + # compile conv2d for x86 (skylake, cascadelake) and test + # assembly contains *pmadd* instructions + ("uint8", "int8", "int32"), + # Check that int8 x int8 goes through legalization so that + # fast instructions can be picked up. 
+ ("int8", "int8", "int32"), + ], + ) + def test_uses_intrinsic( + self, + fast_int8_intrinsic, + assembly, + ): + assert fast_int8_intrinsic in assembly - # Ensure that code is generated when datatypes are not HW supported. - # dtypes = ('uint8', 'uint8', 'int32') - # asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO', - # dtypes=dtypes) - # # Check that intrinisic is not present in the assembly. - # assert not _has_fast_int8_instructions(asm, target) + # For datatypes that don't have HW support, ensure that code is + # generated without the fast int8 intrinsic. + @tvm.testing.parametrize_targets(*supported_targets) + @pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")]) + def test_no_intrinsic( + self, + fast_int8_intrinsic, + assembly, + ): + assert fast_int8_intrinsic not in assembly # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. - target = "llvm -mcpu=x86-64" - if tvm.testing.device_enabled(target): - fast_int8_dtypes = ("uint8", "int8", "int32") - asm = _compile( - ic=16, - oc=32, - target=target, - data_layout="NCHW", - kernel_layout="OIHW", - dtypes=fast_int8_dtypes, - ) - # Check that vector int mult and add instructions are generated. 
- assert "pmulhw" in asm and "paddd" in asm + @tvm.testing.parametrize_targets(*unsupported_targets) + @pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")]) + def test_uses_vectorized_instruction(self, assembly): + assert "pmulhw" in assembly and "paddd" in assembly @tvm.testing.uses_gpu diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index ef5824c957e83..3310b6b2ed690 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -486,8 +486,7 @@ def before(): beta = relay.var("beta") y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3) y = y[0] - y = relay.Function(analysis.free_vars(y), y) - return y + return relay.Function(analysis.free_vars(y), y) def alter_conv2d(attrs, inputs, tinfos, out_type): data, weight = inputs @@ -509,9 +508,8 @@ def expected(): bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c") add = relay.add(y, bias) y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW") - y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC") - mean = relay.mean(y, axis=3, exclude=True) - var = relay.variance(y, axis=3, exclude=True) + mean = relay.mean(y, axis=1, exclude=True) + var = relay.variance(y, axis=1, exclude=True) denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05)) gamma = relay.var("gamma", shape=(16,)) denom = denom * gamma diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 2b7e3e9eb3a9f..9b4d154360b23 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Test alter op layout pass""" +import pytest + import tvm from tvm import te @@ -1098,6 +1100,74 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_qnn_conv_transpose_requantize_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + y = relay.qnn.op.conv2d_transpose( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + y = relay.qnn.op.requantize( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + out_dtype="int32", + ) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8") + weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8") + x = relay.layout_transform(x, "NHWC", "NCHW") + weight = relay.layout_transform(weight, "HWIO", "OIHW") + y = relay.qnn.op.conv2d_transpose( + x, + weight, + relay.const(1, "int32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "float32"), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + out_dtype="int32", + ) + y = relay.qnn.op.requantize( + y, + relay.const(1, "float32"), + relay.const(1, "int32"), + relay.const(1, "float32"), + relay.const(1, "int32"), + axis=1, + out_dtype="int32", + ) + y = relay.nn.relu(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_convert_kernel_layout(): 
"""Check that convolution kernel layout is correctly transformed.""" @@ -1925,37 +1995,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_type assert test_infer_correct_layout_flag == True +def test_reduce_op_convert_layout(): + for reduce_op in [relay.argmax, relay.mean, relay.max]: + + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = reduce_op(y, axis=[2, 3]) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + x = relay.layout_transform(x, "NCHW", "NHWC") + weight = relay.layout_transform(weight, "OIHW", "HWIO") + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = reduce_op(y, axis=[1, 2]) + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": - test_qnn_binary_no_convert_layout() - test_no_convert_layout() - test_conv_convert_layout() - test_conv_nhwc_convert_layout() - test_conv_bias_pool_convert_layout() - test_conv_concat_convert_layout() - test_dual_path_convert_layout() - test_bn_convert_layout() - test_slice_like_convert_layout() - test_transpose_convert_layout() - test_resnet_convert_layout() - test_scalar_convert_layout() - test_conv_bn_convert_layout() - test_qnn_conv_requantize_convert_layout() - test_qnn_conv_concat_convert_layout() - test_qnn_conv_add_convert_layout() - test_qnn_conv_nhwc_convert_layout() - test_conv_convert_kernel_layout() - 
test_conv_transpose_convert_layout() - test_conv_roi_align_convert_layout() - test_conv_roi_pool_convert_layout() - test_conv_strided_slice_convert_layout() - test_deformable_conv_bias_pool_convert_layout() - test_default_keyword() - test_different_ops_convert_layout() - test_no_desired_layout() - test_convert_with_config() - test_conv_squeeze_convert_layout() - test_conv_reduce_convert_layout() - test_conv_strided_slice_axes_convert_layout() - test_image_resize_convert_layout() - test_conv_image_resize_convert_layout() - test_infer_correct_layout() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index a34c4ac6f705c..5b61733bbd764 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -72,6 +72,31 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +@tvm.testing.uses_gpu +def test_dynamic_to_static_squeeze(): + def verify_squeeze(shape, axis, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(axis, "float32")) + z = relay.squeeze(x, relay.shape_of(y)) + func = run_infer_type(relay.Function([x, y], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("squeeze") + assert "axis=" in zz.astext() + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=axis).astype("float32") + ref_res = np.squeeze(x_data, axis) + verify_func(func2, [x_data, y_data], ref_res) + + verify_squeeze((1, 3, 4, 1), (0,), (3, 4, 1)) + verify_squeeze((1, 3, 4, 1), (3,), (1, 3, 4)) + verify_squeeze((1, 3, 4, 1), (0, 3), (3, 4)) + + @tvm.testing.uses_gpu def 
test_dynamic_to_static_double_reshape(): def verify_reshape(shape, newshape): diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3680310b4f926..c49d837ed9201 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -268,6 +268,19 @@ def test_fake_quantize_avgpool(): compare_fq_to_int(op, [x_np], True) +def test_fake_quantize_global_avg_pool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.global_avg_pool2d(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np], True) + + def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py new file mode 100644 index 0000000000000..2252d8a235c90 --- /dev/null +++ b/tests/python/relay/test_pass_plan_devices.py @@ -0,0 +1,1320 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License + + +"""Unit tests for the PlanDevices pass. We check: + - The pass alone given the expected AST, though we need to manually run InferTypes. + - The pass is idempotent. + - Execution on the VM backend yields the correct result.""" + +import tvm +from tvm import relay +import tvm.testing +import numpy as np + +CPU = tvm.device("cpu") # device_type=1 +GPU = tvm.device("cuda") # device_type=2 +DEFAULT = GPU + +core = tvm.IRModule() +core.import_from_std("core.rly") + + +def rewrite_and_assert(in_mod, expected_mod): + """Manually run the pass and assert it's structurally equals to the expected.""" + actual_mod = relay.transform.InferType()(in_mod) + actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.InferType()(actual_mod) + expected_mod = relay.transform.InferType()(expected_mod) + if not tvm.ir.structural_equal(actual_mod, expected_mod, True): + # Print everything in full so we can see what's going on when things fail. + print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. 
+ tvm.ir.assert_structural_equal(actual_mod, expected_mod, True) + + +def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): + """Test the standard compilation flow gives us a function which agrees with the Numpy + reference implementation.""" + if not tvm.runtime.enabled("cuda"): + print("Not evaluating since GPU is not available") + return + with tvm.transform.PassContext(opt_level=3): + compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + actual = compiled(*args).numpy() + expected = reference_func(*args) + tvm.testing.assert_allclose(actual, expected) + + +def rand(shape): + return np.random.rand(*shape).astype("float32") + + +def rands(shape, n): + return [rand(shape) for i in range(n)] + + +def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args): + """Test in_mod against expected_mod and reference_func using args.""" + # Correctness + rewrite_and_assert(in_mod, expected_mod) + # Idempotence + rewrite_and_assert(expected_mod, expected_mod) + # The VM can compile and possibly even run the module + # TODO(mbs): Disabled until VM supports new device planning. 
+ # if not (reference_func is None) and not (args is None): + # eval_and_assert(in_mod, reference_func, args) + + +def test_plain(): + # Everything defaults to GPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[2, 2, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu(): + # Force some args to be on CPU, rest default to GPU. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu_via_copy(): + # As for test_left_add_on_cpu, but with an explicit device_copy. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_both_adds_on_cpu(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + subtract(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = add(%c, %d); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + subtract(%4, %5) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def 
test_sharing(): + # The same add sub-expression is annotated twice. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = on_device(%0, device_type=1); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = on_device(%0, device_type=1, is_fixed=True); + %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + subtract(%3, %4) + } + """ + ) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_let_on_cpu(): + # The device for a let-bound expression can flow from uses of the let-bound var. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %l = add(%a, %b); + let %r = add(%c, %d); + %0 = on_device(%l, device_type=1); + subtract(%0, %r) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + let %l = on_device(%0, device_type=1, is_fixed=True); + let %r = add(%c, %d); + %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + subtract(%1, %r) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_param_on_cpu(): + # Devices for function parameters flow to call sites. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + %0 = add(%x, %y); + on_device(%0, device_type=1) + }; + %1 = %f(%a, %b); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=1) { + let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_result_on_cpu(): + # Devices for call sites flow to function results. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + let %f = on_device(%0, device_type=1, is_fixed=True); + %1 = %f(%a, %b); + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %4 = add(%c, %d); + subtract(%3, %4) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_higher_order(): + # The constraint on %a flows back to %y via %f and %h + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%g) { + fn (%a) { + %0 = on_device(%a, device_type=1); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %f = fn (%g, param_device_types=[2], result_device_type=2) { + fn (%a, param_device_types=[1], result_device_type=2) { + %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b, param_device_types=[2], 
result_device_type=2) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def ref(x, y): + def f(g): + return lambda a: np.add(g(a), x) + + def h(b): + return np.negative(b) + + return np.subtract(x, f(h)(y)) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_function_in_tuple(): + # Since %f ends up in a tuple its argument and result is forced to be on the CPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = on_device(%b, device_type=1); + add(%a, %0) + }; + let %t = (%f, %x); + %1 = %t.1; + %2 = %t.0; + %2(%1, %y) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + add(%a, %b) + }; + let %t = (%f, %x); + %0 = %t.1; + %1 = %t.0; + %1(%0, %y) + } + """ + ) + + def ref(x, y): + return np.add(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_device_copy(): + const = rand((5, 7)) + metatable = {"relay.Constant": [relay.const(const)]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return 
np.add(x, const) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_shape_func(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { + %0 = fn (%y: Tensor[(?), float32]) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%x, device_type=2, is_fixed=True); + %2 = vm.shape_of(%1, dtype="int64"); + %3 = (%2,); + %4 = (%s,); + vm.shape_func(%p, %3, %4, is_input=[False]) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], + param_device_types=[2, 1], result_device_type=1) { + %0 = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = vm.shape_of(%x, dtype="int64"); + %2 = (%1,); + %3 = (%s,); + vm.shape_func(%p, %2, %3, is_input=[False]) + } + """ + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_shape_of(): + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # result defaults to the result device for @main which is the CPU, thus forcing a copy. + # TODO(mbs): Perhaps the defaulting heuristics are being too clever? 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32]) { + %0 = on_device(%x, device_type=2, is_fixed=True); + vm.shape_of(%0, dtype="int64") + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + vm.shape_of(%x, dtype="int64") + } + """ + ) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_alloc_storage(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + # Don't try to execute, too fiddly to setup. 
+ exercise(input(), expected(), None, None) + + +def test_alloc_tensor(): + shape = np.array([3, 2]) + metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[]) { + memory.alloc_tensor(%sto, 0, meta[relay.Constant][0], + const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { + %0 = on_device(0, device_type=1, is_fixed=True); + %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_reshape_tensor(): + newshape = [2, 4, 2] + metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32]) { + vm.reshape_tensor(%x, meta[relay.Constant][0], newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) { + %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.reshape(x, newshape) + + exercise(input(), expected(), ref, rands((2, 8), 1)) + + +def test_dynamic_input(): + # There's nothing special about inferring devices for partially unknown types. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) { + add(%x0, %x1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], + param_device_types=[2, 2], result_device_type=2) { + add(%x0, %x1) + } + """ + ) + + def ref(x0, x1): + return np.add(x0, x1) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_redundant_annotation(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + %3 = on_device(%0, device_type=1); + add(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2], result_device_type=2) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = subtract(%2, %z); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + add(%4, %5) + } + """ + ) + + def ref(x, y, z): + a = np.add(x, y) + return np.add(np.subtract(a, z), a) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_expr(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: 
Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[2, 2, 1], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_all(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + %0 = add(%x, %y); + subtract(%0, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_conv_network(): + r"""The network and devices are as follows: + data1 data2 <--- CPU + | | + conv2d conv2d <--- CPU + \ / + \ / + add <--- GPU + | + conv2d <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32]) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + %4 = add(%2, %3); + %5 = on_device(%4, device_type=2); + %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + on_device(%6, device_type=1) + } + """ + ) + + def expected(): + 
return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %6 = add(%4, %5); + %7 = on_device(%6, device_type=2, is_fixed=True); + %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1); + nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) + } + """ + ) + + # Don't try to execute, we don't have a reference conv2d + exercise(input(), expected(), None, None) + + +def test_tuple_get_item(): + # Note that the device copy should be placed after projection rather than before. This is handled by + # a heuristic in the pass. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32]) { + let %t = split(%x, indices_or_sections=3); + %0 = on_device(%t, device_type=1); + %1 = on_device(%t, device_type=1); + %2 = %0.0; + %3 = %1.1; + %4 = subtract(%2, %3); + on_device(%4, device_type=2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) { + %0 = split(%x, indices_or_sections=3); + let %t = on_device(%0, device_type=1, is_fixed=True); + %1 = %t.0; + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = %t.1; + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + subtract(%5, %6) + } + """ + ) + + def ref(x): + t = np.split(x, 3) + return np.subtract(t[0], t[1]) + + exercise(input(), expected(), ref, rands((3, 3, 4), 1)) + + +def test_propogation(): + r""" The network and devices are as follows: + x <--- CPU + | + log <--- CPU + / \ + log2 log10 <--- GPU + \ / + add <--- GPU + | + tan <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = log(%x); + %1 = on_device(%0, device_type=1); + %2 = log2(%1); + %3 = on_device(%0, device_type=1); + %4 = log10(%3); + %5 = on_device(%2, device_type=2); + %6 = on_device(%4, device_type=2); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2); + %9 = tan(%8); + on_device(%9, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { + %0 = log(%x); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = 
device_copy(%3, src_dev_type=1, dst_dev_type=2); + %5 = log2(%2); + %6 = log10(%4); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + tan(%9) + } + """ + ) + + def ref(x): + y = np.log(x) + return np.tan(np.add(np.log2(y), np.log10(y))) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_fusible_network(): + r""" The network is as follows: + x y <--- GPU + \ / + add <--- GPU + / \ + negative \ <--- CPU + \ \ + \ negative <--- GPU + \ / + add <--- GPU + | + negative <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = negative(%1); + %3 = on_device(%2, device_type=1); + %4 = negative(%0); + %5 = add(%3, %4); + %6 = on_device(%5, device_type=2); + %7 = negative(%6); + on_device(%7, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %3 = negative(%2); + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %6 = negative(%0); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + negative(%9) + } + """ + ) + + def ref(x, y): + z = np.add(x, y) + return np.negative(np.add(np.negative(z), np.negative(z))) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_unpropagatable_graph(): + r"""The network is as follows: + a b <--- CPU + \ / + \ / c d <--- GPU + \ / \ / + add \ / <--- CPU + \ \ / + \ multiply <--- GPU + \ / + subtract <--- CPU + | + 
<--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = multiply(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=2); + %4 = subtract(%2, %3); + on_device(%4, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=1) { + %0 = multiply(%c, %d); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = add(%a, %b); + %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.multiply(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_conditional(): + # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. 
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + let %f = fn (%a) { + %0 = on_device(%y, device_type=1, is_fixed=True); + add(%a, %0) + }; + let %g = fn (%a1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + let %f = fn (%a, param_device_types=[1], result_device_type=1) { + add(%a, %y) + }; + let %g = fn (%a1, param_device_types=[1], result_device_type=1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def ref(x, y, z): + def f(a): + return np.add(a, y) + + def g(a): + return np.subtract(a, y) + + h = f if x else g + return h(z) + + exercise(input(), expected(), ref, [True, rand((5, 7)), rand((5, 7))]) + + +def test_global(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = on_device(%b, device_type=1); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def ref(x, y): + def f(a, b): + return np.add(a, b) + + return f(x, y) + + exercise(input(), 
expected(), ref, rands((5, 7), 2)) + + +def test_ref(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %r = ref(%x); + %0 = on_device(%y, device_type=1); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %r = ref(%x); + %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def ref(x, y): + r = {"value": x} + r["value"] = y + return np.add(x, r["value"]) + + # Don't try to execute, no backend currently supports both hetrogeneous devices and references. + exercise(input(), expected(), None, None) + + +def test_adt(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { + %0 = on_device(%y, device_type=1, is_fixed=True); + %1 = Nil; + %2 = Cons(%0, %1); + let %l = Cons(%x, %2); + match? (%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + %0 = Nil; + %1 = Cons(%y, %0); + let %l = Cons(%x, %1); + match? 
(%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def ref(x, y): + l = [x, y] + return l[0] + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/relay/utils/external_codegen.py b/tests/python/relay/utils/external_codegen.py index 85583f6ccc5d0..2d73ef85be289 100644 --- a/tests/python/relay/utils/external_codegen.py +++ b/tests/python/relay/utils/external_codegen.py @@ -59,7 +59,7 @@ def parametrize_external_json_codegen_checks(test): def update_lib(lib): test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - source_dir = os.path.join(test_dir, "..", "..", "..") + source_dir = os.path.join(test_dir, "..", "..", "..", "..") contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index 96359860f5695..8c125af721634 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -28,7 +28,7 @@ _conv2d_nhwc_implement = { "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc), - "gpu": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc), + "gpu": (topi.gpu.conv2d_nhwc, topi.gpu.schedule_conv2d_nhwc), "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc), "arm_cpu": ( topi.arm_cpu.conv2d_nhwc_spatial_pack, diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py new file mode 100644 index 0000000000000..51ec9ea87ed3f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule.arg_info import ArgInfo, TensorInfo +from tvm.script import ty + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.tir +def Matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (128, 256), "float32") + B = tir.match_buffer(b, (256, 512), "float32") + C = tir.match_buffer(c, (128, 512), "float32") + with tir.block([128, 256, tir.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_tensor_info_creation(): + info = TensorInfo("float32", [1, 224, 224, 3]) + info = str(info) + assert info == 'TensorInfo("float32", [1, 224, 224, 3])' + + +def test_meta_schedule_tensor_info_as_json(): + info = TensorInfo("float32", [1, 224, 224, 3]) + info = info.as_json() + assert info == ["TENSOR", "float32", [1, 224, 224, 3]] + + +def test_meta_schedule_tensor_info_from_json(): + info = ["TENSOR", "float32", [1, 224, 224, 3]] + info = TensorInfo.from_json(info) + assert str(info) == 'TensorInfo("float32", [1, 
224, 224, 3])' + + +def test_meta_schedule_arg_info_from_prim_func(): + a_info, b_info, c_info = ArgInfo.from_prim_func(Matmul) + assert str(a_info) == 'TensorInfo("float32", [128, 256])' + assert str(b_info) == 'TensorInfo("float32", [256, 512])' + assert str(c_info) == 'TensorInfo("float32", [128, 512])' + + +if __name__ == "__main__": + test_meta_schedule_tensor_info_creation() + test_meta_schedule_tensor_info_as_json() + test_meta_schedule_tensor_info_from_json() + test_meta_schedule_arg_info_from_prim_func() diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py new file mode 100644 index 0000000000000..feef023675b04 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +"""Test Meta Schedule Database""" +import os.path as osp +import sys +import tempfile +from typing import Callable + +import pytest + +import tvm +from tvm import tir +from tvm.ir.module import IRModule +from tvm.meta_schedule.arg_info import ArgInfo +from tvm.meta_schedule.database import JSONDatabase, TuningRecord +from tvm.script import ty +from tvm.tir import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +class MatmulRelu: + def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + D = tir.match_buffer(d, (16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + i_tiles = [1, 1, 2, 512] + j_tiles = [1, 512, 1, 2] + k_tiles = [256, 4] + i_0, i_1, i_2, i_3 = 
sch.split(loop=i, factors=i_tiles) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=j_tiles) + k_0, k_1 = sch.split(loop=k, factors=k_tiles) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def _create_schedule(mod: IRModule, sch_fn: Callable[[Schedule], None]) -> Schedule: + sch = tir.Schedule(mod=mod, debug_mask="all") + sch_fn(sch) + return sch + + +def _create_tmp_database(tmpdir: str) -> JSONDatabase: + path_workload = osp.join(tmpdir, "workloads.json") + path_tuning_record = osp.join(tmpdir, "tuning_records.json") + return JSONDatabase(path_workload, path_tuning_record) + + +def _equal_record(a: TuningRecord, b: TuningRecord): + assert str(a.trace) == str(b.trace) + assert str(a.run_secs) == str(b.run_secs) + # AWAIT(@zxybazh): change to export after fixing "(bool)0" + assert str(a.target) == str(b.target) + assert tvm.ir.structural_equal(a.workload.mod, b.workload.mod) + for arg0, arg1 in zip(a.args_info, b.args_info): + assert str(arg0.as_json()) == str(arg1.as_json()) + + +def test_meta_schedule_tuning_record_round_trip(): + mod: IRModule = Matmul() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + workload = database.commit_workload(mod) + record = TuningRecord( + _create_schedule(mod, _schedule_matmul).trace, + [1.5, 2.5, 1.8], + workload, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ) + database.commit_tuning_record(record) + new_record = TuningRecord.from_json(record.as_json(), workload) + _equal_record(record, new_record) + + +def test_meta_schedule_database_create(): + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + assert osp.exists(database.path_workload) + assert osp.exists(database.path_tuning_record) + + +def test_meta_schedule_database_add_entry(): + mod: IRModule = Matmul() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + 
workload = database.commit_workload(mod) + record = TuningRecord( + _create_schedule(mod, _schedule_matmul).trace, + [1.5, 2.5, 1.8], + workload, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ) + database.commit_tuning_record(record) + assert len(database) == 1 + (ret,) = database.get_top_k(workload, 3) + _equal_record(ret, record) + + +def test_meta_schedule_database_missing(): + mod: IRModule = Matmul() + mod_2: IRModule = MatmulRelu() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + workload = database.commit_workload(mod) + workload_2 = database.commit_workload(mod_2) + record = TuningRecord( + _create_schedule(mod, _schedule_matmul).trace, + [1.5, 2.5, 1.8], + workload, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ) + database.commit_tuning_record(record) + ret = database.get_top_k(workload_2, 3) + assert len(ret) == 0 + + +def test_meta_schedule_database_sorting(): + mod: IRModule = Matmul() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + token = database.commit_workload(mod) + trace = _create_schedule(mod, _schedule_matmul).trace + records = [ + TuningRecord( + trace, + [7.0, 8.0, 9.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [1.0, 2.0, 3.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [4.0, 5.0, 6.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [1.1, 1.2, 600.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + 
TuningRecord( + trace, + [1.0, 100.0, 6.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [4.0, 9.0, 8.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + ] + for record in records: + database.commit_tuning_record(record) + ret = database.get_top_k(token, 2) + assert len(ret) == 2 + try: + _equal_record(ret[0], records[2]) + _equal_record(ret[1], records[1]) + except AssertionError: + _equal_record(ret[0], records[1]) + _equal_record(ret[1], records[2]) + + +def test_meta_schedule_database_reload(): + mod: IRModule = Matmul() + with tempfile.TemporaryDirectory() as tmpdir: + database = _create_tmp_database(tmpdir) + token = database.commit_workload(mod) + trace = _create_schedule(mod, _schedule_matmul).trace + records = [ + TuningRecord( + trace, + [7.0, 8.0, 9.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [1.0, 2.0, 3.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + TuningRecord( + trace, + [4.0, 5.0, 6.0], + token, + tvm.target.Target("llvm"), + ArgInfo.from_prim_func(func=mod["main"]), # pylint: disable=unsubscriptable-object + ), + ] + for record in records: + database.commit_tuning_record(record) + new_database = JSONDatabase( # pylint: disable=unused-variable + path_workload=database.path_workload, + path_tuning_record=database.path_tuning_record, + ) + token = new_database.commit_workload(mod) + ret = new_database.get_top_k(token, 2) + assert len(ret) == 2 + try: + _equal_record(ret[0], records[2]) + _equal_record(ret[1], records[1]) + except AssertionError: + _equal_record(ret[0], records[1]) + _equal_record(ret[1], records[2]) + + +if __name__ == "__main__": + 
sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py new file mode 100644 index 0000000000000..3c8aee0c6d58f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -0,0 +1,571 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Meta Schedule Runner """ + +import itertools +import sys +import time +from typing import Any, List + +import numpy as np +import pytest + +import tvm +from tvm import tir +from tvm._ffi import register_func +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import ( + EvaluatorConfig, + PyRunner, + RPCConfig, + RPCRunner, + RunnerFuture, + RunnerInput, +) +from tvm.meta_schedule.runner.rpc_runner import ( + default_alloc_argument as rpc_default_alloc_argument, +) +from tvm.meta_schedule.testing import LocalRPC +from tvm.meta_schedule.utils import get_global_func_with_default_on_worker +from tvm.rpc import RPCSession +from tvm.runtime import Device, Module +from tvm.script import ty +from tvm.target import Target +import tvm.testing +from tvm.tir import FloatImm + +MATMUL_N = 16 +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking + + +@tvm.script.tir +class MatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +class MatmulReluModule: + def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + D = tir.match_buffer(d, (16, 16), "float32") + C = tir.alloc_buffer((16, 16), "float32") + with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") 
as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + with tir.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = tir.max(C[vi, vj], 0.0) + + +@tvm.script.tir +class BatchMatmulModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16, 32, 32]) + B = tir.match_buffer(b, [16, 32, 32]) + C = tir.match_buffer(c, [16, 32, 32]) + with tir.block([16, 32, 32, tir.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: + with tir.init(): + C[vn, vi, vj] = 0.0 + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@tvm.script.tir +class AddModule: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [32], "float32") + B = tir.match_buffer(b, [32], "float32") + C = tir.match_buffer(c, [32], "float32") + with tir.block([32], "add") as [vi]: + C[vi] = A[vi] + B[vi] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def _clean_build(artifact_path: str) -> None: + f_clean_build = get_global_func_with_default_on_worker("meta_schedule.remove_build_dir", None) + if f_clean_build is not None: + f_clean_build(artifact_path) + else: + raise RuntimeError("Unable to find remove_build_dir function.") + + +def test_meta_schedule_rpc_single_run(): + """Test meta schedule rpc runner for a single run""" + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + 
TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_rpc_multiple_runs(): + """Test meta schedule rpc runner for multiple runs""" + # Build the module + mods = [ + MatmulModule(), + MatmulReluModule(), + BatchMatmulModule(), + ] + builder = LocalBuilder() + builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods] + builder_results = builder.build(builder_inputs) + for builder_result in builder_results: + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + args_infos = [ + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + [ + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + TensorInfo("float32", [16, MATMUL_M, MATMUL_M]), + ], + ] + + runner_inputs = [ + RunnerInput(builder_results[i].artifact_path, "llvm", args_infos[i]) + for i in range(len(mods)) + ] + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + 
tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner(rpc_config, evaluator_config) + # Run the module + runner_futures = runner.run(runner_inputs) + runner_results = [runner_future.result() for runner_future in runner_futures] + + for runner_result in runner_results: + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + for builder_result in builder_results: + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_py_runner(): + """Test meta schedule PyRunner""" + + class TestRunner(PyRunner): + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + raise ValueError("TestRunner") + + runner = TestRunner() + with pytest.raises(ValueError, match="TestRunner"): + runner.run([]) + + +def test_meta_schedule_rpc_runner_time_out(): + """Test meta schedule RPC Runner time out""" + + def initializer(): + @register_func("meta_schedule.runner.test_time_out") + def timeout_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + time.sleep(2) + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=1, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + 
f_create_session="meta_schedule.runner.test_time_out", + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: Timeout, killed after" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_rpc_runner_exception(): + """Test meta schedule RPC Runner exception""" + + def initializer(): + @register_func("meta_schedule.runner.test_exception") + def exception_session_creator( # pylint: disable=unused-variable + rpc_config: RPCConfig, # pylint: disable=unused-argument + ) -> RPCSession: + raise Exception("Test") + + runner_input = RunnerInput( + "test", + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + initializer=initializer, + f_create_session="meta_schedule.runner.test_exception", + ) + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + + assert runner_result.error_msg is not None and runner_result.error_msg.startswith( + "RPCRunner: An exception occurred\n" + ) + assert runner_result.run_secs is None + + +def test_meta_schedule_runner_matmul_test(): + """Test meta schedule runner with add module""" + + def _check_correct_matmul( + args_before: List[np.ndarray], + args_after: List[np.ndarray], + ) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = np.matmul(a_before, b_before) + assert (a_before == a_after).all() + assert 
(b_before == b_after).all() + tvm.testing.assert_allclose(c_before, c_after, rtol=1e-5) + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument(session, device, args_info, alloc_repeat) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument + rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_matmul(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = MatmulModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + 
builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + TensorInfo("float32", (MATMUL_N, MATMUL_N)), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +def test_meta_schedule_runner_add_test(): + """Test meta schedule runner with add module""" + + def _check_correct_add(args_before: List[np.ndarray], args_after: List[np.ndarray]) -> None: + a_before, b_before, c_before = args_before + a_after, b_after, c_after = args_after + c_before = a_before + b_before + assert (a_before == a_after).all() + assert (b_before == b_after).all() + assert (c_before == c_after).all() + + def test_alloc_argument( + session: RPCSession, + device: Device, + args_info: Any, + alloc_repeat: int, + ) -> List[Any]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_before = [] # type: ignore + repeated_args = rpc_default_alloc_argument( + session, + device, + args_info, + alloc_repeat, + ) + for args in repeated_args: + repeated_args_before.append([arg.numpy() for arg in args]) # type: ignore + return repeated_args + + def test_run_evaluator( + session: RPCSession, # pylint: disable=unused-argument 
+ rt_mod: Module, + device: Device, + evaluator_config: EvaluatorConfig, + repeated_args: List[Any], + ) -> List[float]: + global repeated_args_before # pylint: disable=global-variable-undefined, invalid-name + repeated_args_after = [] + evaluator = rt_mod.time_evaluator( + func_name=rt_mod.entry_name, + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs: List[List[float]] = [] + for args in repeated_args: + device.sync() + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + repeated_args_after.append([arg.numpy() for arg in args]) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + for args_before, args_after in zip( + repeated_args_before, # type: ignore + repeated_args_after, + ): + _check_correct_add(args_before, args_after) + del repeated_args_before # type: ignore + return costs + + # Build the module + mod = AddModule() + builder = LocalBuilder() + (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) + assert builder_result.artifact_path is not None + assert builder_result.error_msg is None + + runner_input = RunnerInput( + builder_result.artifact_path, + "llvm", + [ + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + TensorInfo("float32", [MATMUL_M]), + ], + ) + + with LocalRPC() as rpc: + rpc_config = RPCConfig( + tracker_host=rpc.tracker_host, + tracker_port=rpc.tracker_port, + tracker_key=rpc.tracker_key, + session_priority=1, + session_timeout_sec=100, + ) + evaluator_config = EvaluatorConfig( + number=1, + repeat=1, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ) + runner = RPCRunner( + rpc_config, + evaluator_config, + f_alloc_argument=test_alloc_argument, + f_run_evaluator=test_run_evaluator, + ) + # Run the module + (runner_future,) = 
runner.run([runner_input]) + runner_result = runner_future.result() + assert runner_result.error_msg is None + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + _clean_build(builder_result.artifact_path) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py new file mode 100644 index 0000000000000..6e90bddb84b41 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Meta Schedule SearchStrategy """ +# pylint: disable=missing-function-docstring +from typing import List + +import sys + +import pytest + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace + +from tvm.script import ty +from tvm.tir.schedule import Schedule, Trace + + +MATMUL_M = 32 + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (32, 32), "float32") + B = tir.match_buffer(b, (32, 32), "float32") + C = tir.match_buffer(c, (32, 32), "float32") + with tir.block([32, 32, tir.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + return str(trace_1) == str(trace_2) + + +def _schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_replay_trace(): + num_trials_per_iter = 7 + num_trials_total = 20 + + (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul()) + replay = 
ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul()) + replay.initialize_with_tune_context(tune_context) + + num_trials_each_round: List[int] = [] + replay.pre_tuning([example_sch]) + while True: + candidates = replay.generate_measure_candidates() + if candidates is None: + break + num_trials_each_round.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + assert _is_trace_equal(candidate.sch, example_sch) + runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) + replay.notify_runner_results(runner_results) + replay.post_tuning() + assert num_trials_each_round == [7, 7, 6] + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py new file mode 100644 index 0000000000000..3ab60aced197e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Meta Schedule SpaceGenerator """ +# pylint: disable=missing-function-docstring + +import sys +import math + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty + +from tvm.tir.schedule import Schedule, Trace +from tvm.meta_schedule.space_generator import ScheduleFn, SpaceGeneratorUnion + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_space_generator_schedule_fn(): + mod = Matmul() + space_generator = ScheduleFn(sch_fn=schedule_matmul) + design_spaces = space_generator.generate_design_space(mod) + assert len(design_spaces) == 1 + (schedule,) = design_spaces + _check_correct(schedule) + + +def test_meta_schedule_design_space_generator_union(): + mod = Matmul() + space_generator = ScheduleFn(sch_fn=schedule_matmul) + 
space_generator_union = SpaceGeneratorUnion([space_generator, space_generator]) + design_spaces = space_generator_union.generate_design_space(mod) + assert len(design_spaces) == 2 + for design_space in design_spaces: + _check_correct(design_space) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py new file mode 100644 index 0000000000000..2da4c85ab4218 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test the tune context of meta schedule.""" + +import sys +import pytest + +import tvm +from tvm import tir +from tvm.script import ty +from tvm.target import Target +from tvm.meta_schedule import TuneContext + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring + + +def test_tune_context_create(): + mod = Matmul() + context = TuneContext(mod=mod, target=Target("llvm"), task_name="Test Task") + assert context.num_threads > 0 + assert context.rand_state != -1 + assert context.task_name == "Test Task" + assert context.mod == mod or tvm.ir.structural_equal(context.mod, mod) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_target_codegen_hexagon.py b/tests/python/unittest/test_target_codegen_hexagon.py index c8b48993967b2..79f2fb06a1ec2 100644 --- a/tests/python/unittest/test_target_codegen_hexagon.py +++ b/tests/python/unittest/test_target_codegen_hexagon.py @@ -17,30 +17,27 @@ import numpy as np import os +import pytest import re +import sys import tvm import tvm.relay +import tvm.testing import tvm.contrib.hexagon as hexagon -def check_prereq_and_setup(): - if tvm.target.codegen.llvm_version_major() <= 7: - print("Skipping test: need LLVM 7 or later for codegen") - return False - if os.name != "posix": - print("Skipping test on non-POSIX 
platforms") - return False - if not tvm.runtime.enabled("hexagon"): - print("Hexagon runtime not enabled") - return False +@pytest.fixture(autouse=True) +def register_linker(): + original_linker = tvm.contrib.hexagon.hexagon_link() # Register a phony linker, so that we can test codegen without a Hexagon toolchain. hexagon.register_linker(lambda: "/bin/true") - return True + yield None + # Restore registration. + hexagon.register_linker(original_linker) +@tvm.testing.requires_hexagon def test_basic(): - if not check_prereq_and_setup(): - return target = tvm.target.hexagon("v66", hvx=128) def check_add(offload): @@ -67,9 +64,8 @@ def check_add(offload): check_add(False) +@tvm.testing.requires_hexagon def test_llvm_target_features(): - if not check_prereq_and_setup(): - return target = tvm.target.hexagon("v66", hvx=128) # Define some trivial compute A = tvm.te.placeholder((128,), dtype="uint8", name="A") @@ -82,9 +78,8 @@ def test_llvm_target_features(): assert fs # Check that it's non-empty +@tvm.testing.requires_hexagon def test_alloc_vtcm(): - if not check_prereq_and_setup(): - return target = tvm.target.hexagon("v66") buf_len = 2048 @@ -109,9 +104,8 @@ def test_alloc_vtcm(): assert "HexagonBackendFreeVTCM" in calls +@tvm.testing.requires_hexagon def test_llvm_options(): - if not check_prereq_and_setup(): - return target = tvm.target.hexagon("v66", llvm_options="-hexagon-noopt") Zero = tvm.te.compute((10,), lambda _: tvm.tir.const(0, "int32")) s = tvm.te.create_schedule(Zero.op) @@ -121,10 +115,8 @@ def test_llvm_options(): assert re.search("-hexagon-noopt", str(target)) +@tvm.testing.requires_hexagon def test_linked_params_codegen(): - if not check_prereq_and_setup(): - return - # A simple model (a single conv2d) to trigger parameter separation: mod_lines = [ '#[version = "0.0.5"]', @@ -185,8 +177,4 @@ def test_linked_params_codegen(): if __name__ == "__main__": - test_basic() - test_llvm_target_features() - test_alloc_vtcm() - test_llvm_options() - 
test_linked_params_codegen() + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 78a8c51178490..efb2073e08625 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -156,6 +156,54 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: ) +@tvm.script.tir +def high_dim_opaque_access(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 32, 64)) + for i, j, k in tir.grid(16, 2, 4): + with tir.block([]): + As_0 = tir.var("int32") + As_1 = tir.var("int32") + tir.reads([]) + tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) + sub_A = tir.match_buffer( + A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], + (16, 16), + strides=[As_0, As_1], + offset_factor=1, + ) + tir.evaluate( + tir.intrin_test( + sub_A.data, + sub_A.elem_offset, + sub_A.strides[0], + sub_A.strides[1], + sub_A.shape[0], + sub_A.shape[1], + dtype="handle", + ) + ) + + +@tvm.script.tir +def transformed_high_dim_opaque_access(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 32, 64)) + for i, j, k in tir.grid(16, 2, 4): + with tir.block([]): + tir.reads([]) + tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) + tir.evaluate( + tir.intrin_test( + A.data, + i * 2048 + j * 1024 + k * 16, + 64, + 1, + 16, + 16, + dtype="handle", + ) + ) + + @tvm.script.tir def recursive_match(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (64, 64, 64)) @@ -419,6 +467,10 @@ def test_opaque_access(): _check(opaque_access, transformed_opaque_access) +def test_high_dim_opaque_access(): + _check(high_dim_opaque_access, transformed_high_dim_opaque_access) + + def test_recursive_match(): _check(recursive_match, transformed_recursive_match) @@ -447,6 +499,7 @@ def test_fail_match_func_param(): if __name__ == "__main__": test_buffer_load_store() test_opaque_access() + 
test_high_dim_opaque_access() test_recursive_match() test_symbolic_match() test_rank0_buffer() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index dbae0b6fa516d..de94464187b0d 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import te import numpy as np @@ -104,6 +105,11 @@ def test_cast(): assert isinstance(z, tvm.tir.Broadcast) assert z.lanes == 4 + s = tvm.tir.StringImm("s") + with pytest.raises(tvm.error.TVMError) as cm: + s.astype("int") + assert "Can't cast a handle to other types" in str(cm.execption) + def test_attr(): x = te.var("x") @@ -468,28 +474,4 @@ def test_block_blockrealize(): if __name__ == "__main__": - test_intimm_cond() - test_buffer_load_store() - test_vars() - test_prim_func() - test_cast() - test_attr() - test_const() - test_scalar_dtype_inference() - test_make() - test_ir() - test_basic() - test_stmt() - test_let() - test_dir() - test_dtype() - test_any() - test_all() - test_bitwise() - test_float_bitwise() - test_shift_bounds() - test_divide_by_zero() - test_isnan() - test_equality() - test_equality_string_imm() - test_block_blockrealize() + pytest.main([__file__]) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 2284f9d996b1f..d11e7f877ccca 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -34,6 +34,16 @@ def elementwise(a: ty.handle, b: ty.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 +@tvm.script.tir +def elementwise_dependent_loops(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i in tir.serial(0, 128): + 
for j, k in tir.grid(i, 128): + with tir.block([128, i, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + @tvm.script.tir def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: A = tir.match_buffer(a, (128, 128, n)) @@ -462,5 +472,13 @@ def test_split_symbolic(): verify_trace_roundtrip(sch=sch, mod=elementwise_symbolic) +def test_fuse_fail_with_dependent_loops(): + sch = tir.Schedule(elementwise_dependent_loops, debug_mask="all") + block_b = sch.get_block("B") + i, j, _ = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(i, j) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index c632f744bb815..a219b8d964573 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -17,6 +17,8 @@ import tvm import tvm.testing from tvm import te +from tvm import tir +from tvm.script import ty import numpy @@ -434,7 +436,6 @@ def test_conv_tiling(): oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.tir.transform.LoopPartition()(mod) @@ -538,6 +539,33 @@ def test_simple_rfactor(): assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) +@tvm.script.tir +def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + A = tir.match_buffer(a, [16], dtype="float32") + B = tir.match_buffer(b, [16], dtype="float32") + C = tir.match_buffer(c, [32], dtype="float32") + for i in tir.serial(0, 16): + tir.store(C.data, i, 
tir.load("float32", A.data, i), True) + for i in tir.serial(0, 16): + tir.store(C.data, i + 16, tir.load("float32", B.data, i + 16), True) + + +def test_explicit_partition_hint(): + A = te.placeholder((16,), name="A") + B = te.placeholder((16,), name="B") + C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C") + s = te.create_schedule(C.op) + s.normalize() + s[C].pragma(s[C].op.axis[0], "loop_partition_hint") + mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None) + with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + mod = tvm.tir.transform.Simplify()(mod) + assert tvm.ir.structural_equal(mod["main"], partitioned_concat) + + if __name__ == "__main__": test_basic() test_const_loop() @@ -559,3 +587,4 @@ def test_simple_rfactor(): test_double_splitting_with_indivisible_factors() test_multilevel_splitting_with_indivisble_factors() test_simple_rfactor() + test_explicit_partition_hint() diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 5f86476c64c7e..3a429721709ea 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -26,6 +26,7 @@ cp ../cmake/config.cmake . 
echo set\(USE_CUBLAS ON\) >> config.cmake echo set\(USE_CUDNN ON\) >> config.cmake echo set\(USE_CUDA ON\) >> config.cmake +echo set\(USE_VULKAN ON\) >> config.cmake echo set\(USE_OPENGL ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_other.sh b/tests/scripts/task_config_build_gpu_other.sh new file mode 100755 index 0000000000000..c11669a2ab0d4 --- /dev/null +++ b/tests/scripts/task_config_build_gpu_other.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is a compiler test to ensure that runtimes can compile +# correctly, even if they aren't actively tested in the CI. + +set -e +set -u + +mkdir -p build2 +cd build2 +cp ../cmake/config.cmake . 
+ +echo set\(USE_OPENCL ON\) >> config.cmake +echo set\(USE_ROCM ON\) >> config.cmake +echo set\(USE_MICRO ON\) >> config.cmake +echo set\(USE_PROFILER ON\) >> config.cmake +echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake +echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(USE_CCACHE OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu_vulkan.sh b/tests/scripts/task_config_build_gpu_vulkan.sh index a5a26a1db0fb8..93adc9667da7b 100755 --- a/tests/scripts/task_config_build_gpu_vulkan.sh +++ b/tests/scripts/task_config_build_gpu_vulkan.sh @@ -16,18 +16,13 @@ # specific language governing permissions and limitations # under the License. -set -e -set -u +# TODO(Lunderberg): Remove this file once the Jenkinsfile in the +# ci-docker-staging branch no longer references it. -mkdir -p build2 -cd build2 -cp ../cmake/config.cmake . +# This file is a backwards compatibility file, as the TVM CI uses the +# Jenkinsfile from the ci-docker-staging branch, but the task scripts +# from the PR branch. -echo set\(USE_OPENCL ON\) >> config.cmake -echo set\(USE_ROCM ON\) >> config.cmake -echo set\(USE_VULKAN ON\) >> config.cmake -echo set\(USE_MICRO ON\) >> config.cmake -echo set\(USE_PROFILER ON\) >> config.cmake -echo set\(USE_LIBBACKTRACE OFF\) >> config.cmake -echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake -echo set\(USE_CCACHE OFF\) >> config.cmake +set -euo pipefail + +./tests/scripts/task_config_build_gpu_other.sh diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 05d1c238b64f2..ecc8ba5d17b05 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -29,5 +29,6 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ -echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." 
-mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ +#TODO(@mikepapadim): This is failing atm +# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu package." +# mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/ diff --git a/tests/scripts/task_python_integration_gpuonly.sh b/tests/scripts/task_python_integration_gpuonly.sh index ac09cb5a14a38..36c3883d4379a 100755 --- a/tests/scripts/task_python_integration_gpuonly.sh +++ b/tests/scripts/task_python_integration_gpuonly.sh @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. -export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;nvptx;opencl -device=mali,aocl_sw_emu" export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" export TVM_RELAY_TEST_TARGETS="cuda" export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu diff --git a/tests/scripts/task_python_unittest_gpuonly.sh b/tests/scripts/task_python_unittest_gpuonly.sh index 22f79bc70ec91..54dd085f18172 100755 --- a/tests/scripts/task_python_unittest_gpuonly.sh +++ b/tests/scripts/task_python_unittest_gpuonly.sh @@ -16,8 +16,22 @@ # specific language governing permissions and limitations # under the License. -export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;vulkan;nvptx;opencl -device=mali,aocl_sw_emu" -export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" +set -euo pipefail + +export PYTEST_ADDOPTS="-m gpu ${PYTEST_ADDOPTS:-}" + +# Test most of the enabled runtimes here. +export TVM_TEST_TARGETS="cuda;opencl;metal;rocm;nvptx;opencl -device=mali,aocl_sw_emu" export TVM_UNITTEST_TESTSUITE_NAME=python-unittest-gpu ./tests/scripts/task_python_unittest.sh + +# Kept separate to avoid increasing time needed to run CI, testing +# only minimal functionality of Vulkan runtime. 
+export TVM_TEST_TARGETS="vulkan -from_device=0" +export TVM_UNITTEST_TESTSUITE_NAME=python-unittest-vulkan + +source tests/scripts/setup-pytest-env.sh + +run_pytest ctypes ${TVM_UNITTEST_TESTSUITE_NAME} tests/python/unittest/test_target_codegen_vulkan.py +run_pytest cython ${TVM_UNITTEST_TESTSUITE_NAME} tests/python/unittest/test_target_codegen_vulkan.py diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py index 468c4d40b942e..67cdfdedce0e8 100644 --- a/tutorials/dev/use_pass_infra.py +++ b/tutorials/dev/use_pass_infra.py @@ -69,20 +69,6 @@ def example(): return relay.Function([x, weight], z2) -############################################################################### -# Let us register layout alteration for a conv2d op so that we can apply the -# layout alteration pass on the example. How alter layout pass works is out -# the scope of this tutorial. - - -@relay.op.register_alter_op_layout("nn.conv2d", level=101) -def alter_conv2d(attrs, inputs, tinfos, out_type): - data, weight = inputs - new_attrs = dict(attrs) - new_attrs["data_layout"] = "NCHW16c" - return relay.nn.conv2d(data, weight, **new_attrs) - - ############################################################################### # Optimize the Program # -------------------- @@ -188,21 +174,6 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): mod3 = seq(mod) print(mod3) -############################################################################### -# The passes applied so far are target independent. The pass infra also -# provides a means to make pass target-aware. For example, the layout -# alteration pass falls in such category. 
- -with tvm.transform.PassContext(opt_level=3): - mod4 = seq(mod) -print(mod4) - -seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()]) -with tvm.transform.PassContext(opt_level=3): - with tvm.target.Target("llvm"): - mod5 = seq1(mod) -print(mod5) - ############################################################################## # Implement a Pass Using Python Decorator # ------------------------------------------ @@ -257,7 +228,6 @@ def visit_constant(self, c): tvm.transform.PrintIR(), relay.transform.EliminateCommonSubexpr(), relay.transform.FuseOps(), - relay.transform.AlterOpLayout(), ] )