diff --git a/.clang-tidy b/.clang-tidy index 5c2a7aa65..1681ed66e 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,6 +1,6 @@ --- InheritParentConfig: true -ExtraArgs: ['-v'] +ExtraArgs: [] FormatStyle: file UseColor: true WarningsAsErrors: '*' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5967a2efe..4d587c640 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,10 +22,12 @@ env: PYTHONDEVMODE: "1" PYTHONUNBUFFERED: "1" PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup COLUMNS: "100" FORCE_COLOR: "1" CLICOLOR_FORCE: "1" UV_INDEX_STRATEGY: "unsafe-best-match" + UV_HTTP_TIMEOUT: "600" XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated UV_CACHE_DIR: "${{ github.workspace }}/.cache/uv" # to be updated @@ -44,7 +46,7 @@ jobs: submodules: recursive - name: Setup Python 3.8 - id: setup-py38 + id: setup-pylowest uses: actions/setup-python@v6 with: python-version: "3.8" # use lowest supported version for linting @@ -52,7 +54,7 @@ jobs: - name: Check AST with Python 3.8 run: | - "${{ steps.setup-py38.outputs.python-path }}" -m compileall -q -f tilelang + "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang - name: Setup Python 3.12 uses: actions/setup-python@v6 diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 6674574c3..605d57ced 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -108,14 +108,11 @@ jobs: - { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" } - { runner: macos-latest, toolkit: "Metal" } python-version: - - "3.8" - # TVM is built with Python 3.8 Limited API, it should work with all Python >= 3.8. - # - "3.9" - # - "3.10" - # - "3.11" - # - "3.12" - # - "3.13" - # - "3.14" + # Wheels are built with Python 3.8 Limited API, they should work with all Python >= 3.8. + # Only build wheels against Python 3.8 Limited API to save CI resources. + # FIXME: Here we use Python 3.9 because our dependency `apache-tvm-ffi` claims to support + # Python 3.8 but it depends on a version of `ml-dtypes` that requires Python >= 3.9. + - "3.9" fail-fast: false timeout-minutes: 120 runs-on: ${{ matrix.target.runner }} diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml index 57af8ea6c..37da4e3c8 100644 --- a/.github/workflows/pr-perfbench-bot.yml +++ b/.github/workflows/pr-perfbench-bot.yml @@ -12,6 +12,17 @@ concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: true # always cancel in-progress +env: + PYTHONDEVMODE: "1" + PYTHONUNBUFFERED: "1" + PYTHONPATH: "" # explicit cleanup + PIP_USER: "" # explicit cleanup + COLUMNS: "100" + FORCE_COLOR: "1" + CLICOLOR_FORCE: "1" + XDG_CACHE_HOME: "${{ github.workspace }}/.cache" # to be updated + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" # to be updated + jobs: perfbench: name: Benchmark between PR and main @@ -31,7 +42,12 @@ jobs: - name: Setup Python uses: actions/setup-python@v6 with: - python-version: "3.9" + python-version: "3.12" + update-environment: true + cache: pip + cache-dependency-path: | + pyproject.toml + requirements*.txt - name: Install merged version run: | diff --git a/3rdparty/tvm b/3rdparty/tvm index 5bf17a346..0f1ebab7b 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 5bf17a34602931e7d7e01cbccf358a21fe972779 +Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c diff --git a/CMakeLists.txt b/CMakeLists.txt index afeccaceb..e53650f73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,11 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "$ENV{CIBUILDWHEEL}") + # Warning came from tvm submodule + string(APPEND CMAKE_CXX_FLAGS " -Wno-dangling-reference") +endif() + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.gitmodules" AND EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/.git") @@ -36,9 +41,18 @@ endif() find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) + message(STATUS "Using ccache: ${CCACHE_PROGRAM}") set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "C compiler launcher") set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") +else() + find_program(SCCACHE_PROGRAM sccache) + if(SCCACHE_PROGRAM) + message(STATUS "Using sccache: ${SCCACHE_PROGRAM}") + set(CMAKE_C_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "C compiler launcher") + set(CMAKE_CXX_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CXX compiler launcher") + set(CMAKE_CUDA_COMPILER_LAUNCHER "${SCCACHE_PROGRAM}" CACHE STRING "CUDA compiler launcher") + endif() endif() # Configs @@ -68,8 +82,6 @@ file(GLOB TILE_LANG_SRCS src/target/utils.cc src/target/codegen_cpp.cc src/target/rt_mod_cpp.cc - # webgpu doesn't have system dependency - src/target/codegen_webgpu.cc # intrin_rule doesn't have system dependency src/target/intrin_rule*.cc ) @@ -181,18 +193,18 @@ install(TARGETS tilelang_cython_wrapper # let libtilelang to search tvm/tvm_runtime in same dir if(APPLE) - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path") -else() - set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN") - set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN") + set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") + set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") +elseif(UNIX) + set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") + set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") endif() -install(TARGETS tvm tvm_runtime tilelang_module tilelang LIBRARY DESTINATION tilelang/lib) - -# Copy tvm cython ext for wheels -# TODO: not necessary for editable builds -if(TVM_BUILD_FROM_SOURCE) - add_dependencies(tilelang tvm_cython) - install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python/tvm/ffi/core.abi3.so" DESTINATION tilelang/3rdparty/tvm/python/tvm/ffi/) -endif() +install( + TARGETS tvm tvm_runtime tilelang_module tilelang + LIBRARY DESTINATION tilelang/lib +) diff --git a/cmake/load_tvm.cmake b/cmake/load_tvm.cmake index 21fe6dfb5..f013c3ba6 100644 --- a/cmake/load_tvm.cmake +++ b/cmake/load_tvm.cmake @@ -11,8 +11,17 @@ endif() set(TVM_INCLUDES ${TVM_SOURCE}/include - ${TVM_SOURCE}/ffi/include ${TVM_SOURCE}/src ${TVM_SOURCE}/3rdparty/dlpack/include ${TVM_SOURCE}/3rdparty/dmlc-core/include ) + +if(EXISTS ${TVM_SOURCE}/ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/ffi/include) +elseif(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/include) +endif() + +if(EXISTS ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) + list(APPEND TVM_INCLUDES ${TVM_SOURCE}/3rdparty/tvm-ffi/3rdparty/dlpack/include) +endif() diff --git a/examples/gemm/README.md b/examples/gemm/README.md index 059d08c84..d7833c97d 100644 --- a/examples/gemm/README.md +++ b/examples/gemm/README.md @@ -4,20 +4,23 @@ TileLang is a domain-specific language designed to simplify the process of writi ## Table of Contents -1. [Getting Started](#getting-started) -2. [Simple GEMM Example](#simple-gemm-example) - - [Code Walkthrough](#code-walkthrough) - - [Compiling and Profiling](#compiling-and-profiling) -3. [Advanced GEMM Features](#advanced-gemm-features) - - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) - - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) - - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) -4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) -5. [Verifying Correctness](#verifying-correctness) -6. [Fine-grained MMA Computations](#fine-grained-mma-computations) - - [Example Workflow](#example-workflow) - - [Summary](#summary) -7. [References](#references) +- [Table of Contents](#table-of-contents) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Installation](#installation) +- [Simple GEMM Example](#simple-gemm-example) + - [Code Walkthrough](#code-walkthrough) + - [Compiling and Profiling](#compiling-and-profiling) +- [Advanced GEMM Features](#advanced-gemm-features) + - [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling) + - [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining) + - [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality) +- [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations) +- [Verifying Correctness](#verifying-correctness) +- [Fine-grained MMA Computations](#fine-grained-mma-computations) + - [Example Workflow](#example-workflow) + - [Summary](#summary) +- [References](#references) --- @@ -25,10 +28,10 @@ TileLang is a domain-specific language designed to simplify the process of writi ### Prerequisites -- **Python 3.8+** -- **NVIDIA GPU** with a recent CUDA toolkit installed +- **Python 3.8+** +- **NVIDIA GPU** with a recent CUDA toolkit installed - **PyTorch** (optional, for easy correctness verification) -- **tilelang** +- **tilelang** - **bitblas** (optional; used for swizzle layout utilities in the advanced examples) ### Installation @@ -87,26 +90,26 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ### Code Walkthrough -1. **Define the Kernel Launch Configuration:** +1. **Define the Kernel Launch Configuration:** ```python with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): ``` This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads. -2. **Shared Memory Allocation:** +2. **Shared Memory Allocation:** ```python A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) ``` Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access. -3. **Local Fragment Accumulation:** +3. **Local Fragment Accumulation:** ```python C_local = T.alloc_fragment((block_M, block_N), accum_dtype) ``` Partial results are stored in registers (or local memory) to reduce writes to global memory. -4. **Pipelined Loading and GEMM:** +4. **Pipelined Loading and GEMM:** ```python for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(...) @@ -114,7 +117,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo ``` Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation. -5. **Copy Out the Results:** +5. **Copy Out the Results:** ```python T.copy(C_local, C[by * block_M, bx * block_N]) ``` @@ -216,10 +219,10 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo return main ``` -**Key Differences vs. Basic Example** -1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). -2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. -3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. +**Key Differences vs. Basic Example** +1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling). +2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization. +3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions. --- @@ -247,7 +250,7 @@ print("Results match!") ## Fine-grained MMA Computations -For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. +For advanced users who require full control over warp-level matrix multiplication operations, TileLang allows you to specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. While higher-level abstractions like `T.gemm(...)` or automatic MMA emitters are sufficient for many use cases, specialized workloads (for example, dequantize gemm may require fine-grained layout transformation on shared to register stage) may benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points. ### Example Workflow @@ -394,10 +397,10 @@ def tl_matmul( ] ``` -1. **Set Up Tile Sizes and Thread Bindings** +1. **Set Up Tile Sizes and Thread Bindings** Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...),` which ensure that the correct number of threads are active, and each thread is bound to a specific role (e.g., warp ID or lane ID). -2. **Allocate Warp-local Fragments** +2. **Allocate Warp-local Fragments** Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like: ```python A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) @@ -406,7 +409,7 @@ def tl_matmul( ``` Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles. -3. **Load Data via `ldmatrix`** +3. **Load Data via `ldmatrix`** Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well: ```python for ki in T.serial(0, (block_K // micro_size_k)): @@ -418,7 +421,7 @@ def tl_matmul( ``` Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers. -4. **Perform the MMA Instruction** +4. **Perform the MMA Instruction** After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially: \[ C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}} @@ -429,7 +432,7 @@ def tl_matmul( ``` Under the hood, this translates into Tensor Core instructions (e.g., `wmma.mma.sync` in PTX), which process multiple data elements per warp in parallel. -5. **Store Results via `stmatrix`** +5. **Store Results via `stmatrix`** Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet: ```python mma_emitter.stmatrix(C_local, C_shared) @@ -444,6 +447,6 @@ By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with ma ## References -- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. -- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. +- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM. +- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA. - [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul. diff --git a/format.sh b/format.sh index 9b6437a27..f2efab4d3 100755 --- a/format.sh +++ b/format.sh @@ -80,6 +80,9 @@ elif [[ "${#FILES[@]}" -gt 0 ]]; then echo "Checking specified files: ${FILES[*]}..." >&2 fi +# Some systems set pip's default to --user, which breaks isolated virtualenvs. +export PIP_USER=0 + # If pre-commit is not installed, install it. if ! python3 -m pre_commit --version &>/dev/null; then python3 -m pip install pre-commit diff --git a/pyproject.toml b/pyproject.toml index af443d52b..6e3070247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,21 +8,27 @@ maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }] license = "MIT" keywords = ["BLAS", "CUDA", "HIP", "Code Generation", "TVM"] classifiers = [ + "Development Status :: 4 - Beta", "Environment :: GPU", "Operating System :: POSIX :: Linux", - "Operating System :: OS Independent", "Operating System :: MacOS", + "Programming Language :: C++", + "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Scientific/Engineering :: Artificial Intelligence", ] dynamic = ["version"] dependencies = [ + "apache-tvm-ffi~=0.1.0", "cloudpickle", "ml-dtypes", "numpy>=1.23.5", @@ -39,11 +45,7 @@ dependencies = [ fp4 = ["ml-dtypes>=0.5.1"] [build-system] -requires = [ - "cython>=3.0.0", - "scikit-build-core", - "setuptools>=63", -] +requires = ["cython>=3.0.0", "scikit-build-core"] build-backend = "scikit_build_core.build" [tool.scikit-build] @@ -170,27 +172,37 @@ build-frontend = "build" environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1" } environment-pass = [ "CUDA_VERSION", + "NO_VERSION_LABEL", + "NO_TOOLCHAIN_VERSION", + "NO_GIT_VERSION", "COLUMNS", + "CMAKE_GENERATOR", + "CMAKE_BUILD_PARALLEL_LEVEL", "FORCE_COLOR", "CLICOLOR_FORCE", ] before-build = "env -0 | sort -z | tr '\\0' '\\n'" windows.before-build = "set" -# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now -manylinux-x86_64-image = "manylinux2014" -manylinux-aarch64-image = "manylinux_2_28" +test-command = [ + "python -c 'import tilelang; print(tilelang.__version__)'", +] [tool.cibuildwheel.linux] -environment = { PYTHONDEVMODE = "1", PYTHONUNBUFFERED = "1", PATH = "/usr/local/cuda/bin:$PATH" } -repair-wheel-command = [ - "auditwheel repair --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}", - "pipx run abi3audit --strict --report {wheel}", -] +environment.PYTHONDEVMODE = "1" +environment.PYTHONUNBUFFERED = "1" +environment.PATH = "/usr/local/cuda/bin:$PATH" +environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" +# Pin to glibc 2.17 for x86 and 2.28 for aarch64 for now +manylinux-x86_64-image = "manylinux2014" # CentOS 7 +manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8 # Install CUDA runtime and stub driver library # manylinux_2_28 uses gcc 14, which needs CUDA 12.8 before-all = """ set -eux +cat /etc/*-release +uname -a + case "$(uname -m)" in "x86_64") yum-config-manager --add-repo https://developer.download.nvidia.cn/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo @@ -205,5 +217,22 @@ esac cudaver="$(echo "${CUDA_VERSION:-"12.4"}" | cut -d '.' -f-2)" v="${cudaver//./-}" -yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" +yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-devel-${v}" nvidia-driver-cuda-libs """ +repair-wheel-command = [ + "auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}", + "pipx run abi3audit --verbose --strict {wheel}", +] + +[tool.cibuildwheel.macos] +repair-wheel-command = [ + "delocate-wheel --verbose --ignore-missing-dependencies --no-sanitize-rpaths --require-archs {delocate_archs} -w {dest_dir} -v {wheel}", + "pipx run abi3audit --verbose --strict {wheel}", +] + +[[tool.cibuildwheel.overrides]] +select = "*linux*x86_64*" +# CentOS 7 is too old to run import test. Do wheel installation test only. +test-command = [ + "echo 'Wheel is installed successfully'", +] diff --git a/requirements-test.txt b/requirements-test.txt index f896c4824..38bdf2d7b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -18,10 +18,11 @@ cython docutils dtlib einops +flash-linear-attention==0.3.2 packaging>=21.0 -pytest-xdist>=2.2.1 pytest-durations pytest-timeout +pytest-xdist>=2.2.1 pytest>=6.2.4 pyyaml requests diff --git a/requirements.txt b/requirements.txt index 49a398844..3ad186ed4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # Runtime requirements +apache-tvm-ffi~=0.1.0 cloudpickle ml-dtypes numpy>=1.23.5 @@ -7,4 +8,3 @@ torch torch>=2.7; platform_system == 'Darwin' tqdm>=4.62.3 typing-extensions>=4.10.0 -flash-linear-attention==0.3.2 \ No newline at end of file diff --git a/src/ir.cc b/src/ir.cc index aea1c3697..3d2b3ecdc 100644 --- a/src/ir.cc +++ b/src/ir.cc @@ -7,6 +7,9 @@ #include "./transform/common/attr.h" #include "op/builtin.h" #include "tvm/ffi/any.h" +#include + +#include "support/ffi_aliases.h" #include #include #include @@ -37,7 +40,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { using namespace tvm::tir; Var var = Var(name, dom->dtype); // Create a frame that represents a loop over the given domain. - ObjectPtr n = make_object(); + ObjectPtr n = tvm::ffi::make_object(); n->vars.push_back(var); n->doms.push_back(Range(0, dom)); n->f_make_for_loop = [](const Array &vars, const Array &doms, @@ -52,7 +55,7 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { ForFrame ParallelFor(const Array &extents, const Map &annotations) { using namespace tvm::tir; - ObjectPtr n = make_object(); + ObjectPtr n = tvm::ffi::make_object(); n->vars.reserve(extents.size()); n->doms.reserve(extents.size()); for (const auto &extent : extents) { @@ -82,7 +85,7 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, const Array> &sync, const Array> &groups) { using namespace tvm::tir; - ObjectPtr n = make_object(); + ObjectPtr n = tvm::ffi::make_object(); DataType dtype = stop.dtype(); n->vars.push_back(Var("v", dtype)); n->doms.push_back(Range(std::move(start), stop)); @@ -113,7 +116,7 @@ ForFrame PersistentFor(const Array &domain, const PrimExpr &wave_size, const PrimExpr &index, PrimExpr group_size) { using namespace tvm::tir; ICHECK(!domain.empty()); - ObjectPtr n = make_object(); + ObjectPtr n = tvm::ffi::make_object(); n->vars.reserve(domain.size()); n->doms.reserve(domain.size()); PrimExpr domain_size = domain[0]; @@ -193,8 +196,8 @@ class KernelLaunchFrameNode : public TIRFrameNode { "frames", &KernelLaunchFrameNode::frames); } - static constexpr const char *_type_key = "tl.KernelLaunchFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(KernelLaunchFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.KernelLaunchFrame", + KernelLaunchFrameNode, TIRFrameNode); public: TVM_DLL void EnterWithScope() final { @@ -218,14 +221,20 @@ class KernelLaunchFrameNode : public TIRFrameNode { */ class KernelLaunchFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(KernelLaunchFrame, TIRFrame, - KernelLaunchFrameNode); + explicit KernelLaunchFrame(ObjectPtr data) + : TIRFrame(::tvm::ffi::UnsafeInit{}) { + ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(KernelLaunchFrame, TIRFrame, + KernelLaunchFrameNode); }; KernelLaunchFrame KernelLaunch(const Array &grid_size, const Optional> &block_size_opt, const Map &attrs) { - ObjectPtr n = make_object(); + ObjectPtr n = + tvm::ffi::make_object(); // If the kernel is a CPU kernel, we don't need to launch any threads. bool is_cpu_kernel_frame = @@ -289,16 +298,14 @@ KernelLaunchFrame KernelLaunch(const Array &grid_size, return KernelLaunchFrame(n); } -TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode); - -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tl.Parallel", ParallelFor) .def("tl.Pipelined", PipelinedFor) .def("tl.Persistent", PersistentFor) .def("tl.KernelLaunch", KernelLaunch); -}); +} class WarpSpecializeFrameNode : public TIRFrameNode { public: @@ -310,8 +317,8 @@ class WarpSpecializeFrameNode : public TIRFrameNode { "frames", &WarpSpecializeFrameNode::frames); } - static constexpr const char *_type_key = "tl.WarpSpecializeFrame"; - TVM_DECLARE_FINAL_OBJECT_INFO(WarpSpecializeFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.WarpSpecializeFrame", + WarpSpecializeFrameNode, TIRFrameNode); public: TVM_DLL void EnterWithScope() final { @@ -330,15 +337,20 @@ class WarpSpecializeFrameNode : public TIRFrameNode { class WarpSpecializeFrame : public TIRFrame { public: - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WarpSpecializeFrame, - TIRFrame, - WarpSpecializeFrameNode); + explicit WarpSpecializeFrame(ObjectPtr data) + : TIRFrame(::tvm::ffi::UnsafeInit{}) { + ICHECK(data != nullptr); + data_ = std::move(data); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(WarpSpecializeFrame, TIRFrame, + WarpSpecializeFrameNode); }; WarpSpecializeFrame WarpSpecialize(const Array &warp_group_ids, const PrimExpr &thread_idx, int warp_group_size = 128) { - ObjectPtr n = make_object(); + ObjectPtr n = + tvm::ffi::make_object(); PrimExpr condition; std::vector warp_groups; warp_groups.reserve(warp_group_ids.size()); @@ -376,13 +388,12 @@ WarpSpecializeFrame WarpSpecialize(const Array &warp_group_ids, return WarpSpecializeFrame(n); } -TVM_REGISTER_NODE_TYPE(WarpSpecializeFrameNode); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.WarpSpecialize", WarpSpecialize); KernelLaunchFrameNode::RegisterReflection(); WarpSpecializeFrameNode::RegisterReflection(); -}); +} } // namespace tl } // namespace tvm diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 5eb4a822d..e9acfeb1c 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -64,13 +64,12 @@ Layout::Layout(Array forward_var, Array forward_index) { } forward_index = forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); - - auto n = make_object(input_size, forward_index); + auto n = tvm::ffi::make_object(input_size, forward_index); data_ = std::move(n); } Layout::Layout(Array input_size, Array forward_index) { - auto n = make_object(input_size, forward_index); + auto n = tvm::ffi::make_object(input_size, forward_index); data_ = std::move(n); } @@ -130,7 +129,6 @@ Array LayoutNode::Forward(const Array &vars) const { Array transformed = forward_index_.Map( [&](const PrimExpr &e) { return Substitute(e, vmap); }); - // Concatenate with the remaining elements from vars Array result; for (size_t i = 0; i < vars.size() - InputDim(); i++) { @@ -212,7 +210,7 @@ Fragment FragmentNode::DeReplicate() const { factor = arith::ZeroAwareGCD(*rep_size, *idx_size); } if (factor == 1) - return GetRef(this); + return tvm::ffi::GetRef(this); Map vmap; vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor + @@ -224,7 +222,7 @@ Fragment FragmentNode::DeReplicate() const { } Fragment FragmentNode::BindThreadRange(Range thread_range) const { - auto n = make_object(*this); + auto n = tvm::ffi::make_object(*this); n->thread_range_ = thread_range; return Fragment(n); } @@ -336,8 +334,8 @@ Fragment::Fragment(Array forward_var, Array forward_index, forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); forward_thread = Substitute(forward_thread, vmap); - auto n = make_object(input_size, forward_index, forward_thread, - replicate_size); + auto n = tvm::ffi::make_object(input_size, forward_index, + forward_thread, replicate_size); data_ = std::move(n); } @@ -348,8 +346,8 @@ Fragment::Fragment(Array input_size, Array forward_index, forward_thread = Substitute( forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}}); } - auto n = make_object(input_size, forward_index, forward_thread, - replicate_size); + auto n = tvm::ffi::make_object(input_size, forward_index, + forward_thread, replicate_size); data_ = std::move(n); } @@ -442,21 +440,6 @@ std::string FragmentNode::DebugOutput() const { return ss.str(); } -bool LayoutNode::SEqualReduce(const LayoutNode *other, - SEqualReducer equal) const { - return equal(this->InputShape(), other->InputShape()) && - equal(this->forward_index_, other->forward_index_); -} - -bool FragmentNode::SEqualReduce(const FragmentNode *other, - SEqualReducer equal) const { - return equal(this->ReplicateExtent(), other->ReplicateExtent()) && - equal(this->InputShape(), other->InputShape()) && - equal(this->ThreadExtent(), other->ThreadExtent()) && - equal(this->forward_index_, other->forward_index_) && - equal(this->forward_thread_, other->forward_thread_); -} - bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const { bool ret = StructuralEqual()(this->InputShape(), other->InputShape()); ret &= StructuralEqual()(this->OutputShape(), other->OutputShape()); @@ -495,10 +478,7 @@ void FragmentNode::RegisterReflection() { .def_ro("replicate_size", &FragmentNode::replicate_size_); } -TVM_REGISTER_NODE_TYPE(LayoutNode); -TVM_REGISTER_NODE_TYPE(FragmentNode); - -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def_packed("tl.Layout", @@ -582,13 +562,13 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("tl.make_linear_layout", [](int stride, int continuous) { return makeGemmLayoutLinear(stride, continuous); }); -}); +} -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; LayoutNode::RegisterReflection(); FragmentNode::RegisterReflection(); -}); +} } // namespace tl } // namespace tvm diff --git a/src/layout/layout.h b/src/layout/layout.h index 0001c803b..97fde85d3 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -8,8 +8,11 @@ #include #include +#include #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace tl { @@ -44,11 +47,10 @@ class LayoutNode : public Object { virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const; - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr const char *_type_key = "tl.Layout"; - bool SEqualReduce(const LayoutNode *other, SEqualReducer equal) const; static void RegisterReflection(); - TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object); + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = + kTVMFFISEqHashKindTreeNode; protected: virtual Map getVarMap() const; @@ -65,7 +67,7 @@ class Layout : public ObjectRef { TVM_DLL Layout(Array forward_var, Array forward_index); TVM_DLL Layout(Array input_size, Array forward_index); - TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Layout, ObjectRef, LayoutNode); }; class FragmentNode : public LayoutNode { @@ -109,9 +111,9 @@ class FragmentNode : public LayoutNode { static void RegisterReflection(); - bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const; - static constexpr const char *_type_key = "tl.Fragment"; - TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode); + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = + kTVMFFISEqHashKindTreeNode; protected: Map getVarMap() const final; @@ -132,7 +134,7 @@ class Fragment : public Layout { PrimExpr forward_thread, PrimExpr replicate_size, Optional replicate_var); - TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode); }; Var InputPlaceholder(size_t idx); diff --git a/src/layout/swizzle.cc b/src/layout/swizzle.cc index 2da308038..e3222b9c0 100644 --- a/src/layout/swizzle.cc +++ b/src/layout/swizzle.cc @@ -6,6 +6,7 @@ #include "swizzle.h" +#include #include #include @@ -86,14 +87,16 @@ SwizzledLayout::SwizzledLayout(Array forward_var, forward_index = forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); }); - auto n = make_object(input_size, forward_index, pattern); + auto n = tvm::ffi::make_object(input_size, forward_index, + pattern); data_ = std::move(n); } SwizzledLayout::SwizzledLayout(Array input_size, Array forward_index, SwizzlePattern pattern) { - auto n = make_object(input_size, forward_index, pattern); + auto n = tvm::ffi::make_object(input_size, forward_index, + pattern); data_ = std::move(n); } @@ -102,14 +105,5 @@ void SwizzledLayoutNode::RegisterReflection() { refl::ObjectDef(); } -bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode *other, - SEqualReducer equal) const { - return equal(this->InputShape(), other->InputShape()) && - equal(this->forward_index_, other->forward_index_) && - pattern_ == other->pattern_; -} - -TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode); - } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/layout/swizzle.h b/src/layout/swizzle.h index 5f7f4f3dd..b0bf5f1c9 100644 --- a/src/layout/swizzle.h +++ b/src/layout/swizzle.h @@ -44,10 +44,9 @@ class SwizzledLayoutNode : public LayoutNode { Layout Inverse() const final; std::string DebugOutput() const final; bool IsEqual(const SwizzledLayoutNode *other, bool skip_index = false) const; - static constexpr const char *_type_key = "tl.SwizzledLayout"; - bool SEqualReduce(const SwizzledLayoutNode *other, SEqualReducer equal) const; static void RegisterReflection(); - TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.SwizzledLayout", SwizzledLayoutNode, + LayoutNode); private: SwizzlePattern pattern_; @@ -62,11 +61,11 @@ class SwizzledLayout : public Layout { Array forward_index, SwizzlePattern pattern); TVM_DLL SwizzledLayout(Array input_size, Array forward_index, SwizzlePattern pattern); - - TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(SwizzledLayout, Layout, + SwizzledLayoutNode); }; } // namespace tl } // namespace tvm -#endif // TVM_TL_LAYOUT_SWIZZLE_H_ \ No newline at end of file +#endif // TVM_TL_LAYOUT_SWIZZLE_H_ diff --git a/src/layout/utils.cc b/src/layout/utils.cc index 22849a0d8..4f533c442 100644 --- a/src/layout/utils.cc +++ b/src/layout/utils.cc @@ -189,7 +189,7 @@ class IterSumMutator { IterMark Mutate(const IterMark &mark) { if (auto *op = mark->source.as()) { - return IterMark(Mutate(GetRef(op)), mark->extent); + return IterMark(Mutate(tvm::ffi::GetRef(op)), mark->extent); } else { return mark; } diff --git a/src/layout/utils.h b/src/layout/utils.h index 87732bf97..0f03a8617 100644 --- a/src/layout/utils.h +++ b/src/layout/utils.h @@ -9,6 +9,8 @@ #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace tl { diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 31c5bfb4d..57e0d8b78 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -42,7 +42,7 @@ using namespace tir; * - The constructed node is stored in this->data_. */ AtomicAdd::AtomicAdd(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { @@ -78,7 +78,7 @@ AtomicAdd::AtomicAdd(Array args, BufferMap vmap) { * @return TileOperator A TileOperator owning the cloned AtomicAddNode. */ TileOperator AtomicAddNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); if (par_op_.defined()) { op->par_op_ = Downcast(par_op_->Clone()); } @@ -549,7 +549,7 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ AtomicAddNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { AtomicAddNode::RegisterReflection(); } } // namespace tl } // namespace tvm \ No newline at end of file diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index ae9cc99af..f3aaacdbe 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -25,8 +25,8 @@ class AtomicAddNode : public TileOperatorNode { IntImm memory_order; ///< Memory order for atomic operations mutable ParallelOp par_op_; ///< Associated parallel operation - static constexpr const char *_type_key = "tl.AtomicAdd"; - TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode, + TileOperatorNode); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; @@ -46,28 +46,6 @@ class AtomicAddNode : public TileOperatorNode { .def_ro("memory_order", &AtomicAddNode::memory_order); } - bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const { - return equal(src, other->src) && equal(dst, other->dst) && - equal(src_range, other->src_range) && - equal(dst_range, other->dst_range) && - equal(use_tma, other->use_tma) && - equal(coalesced_width, other->coalesced_width) && - equal(memory_order, other->memory_order); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(src); - hash_reduce(dst); - hash_reduce(src_range); - hash_reduce(dst_range); - hash_reduce(use_tma); - hash_reduce(coalesced_width); - hash_reduce(memory_order); - } - - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - protected: /// Create SIMT-style parallel loop structure For MakeSIMTLoop(arith::Analyzer *analyzer) const; @@ -85,7 +63,8 @@ class AtomicAddNode : public TileOperatorNode { /// Wrapper class for atomic addition operations class AtomicAdd : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, + AtomicAddNode); TVM_DLL AtomicAdd(Array args, BufferMap vmap); static const Op &Get(); }; @@ -93,4 +72,4 @@ class AtomicAdd : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_ATOMIC_ADD_H_ \ No newline at end of file +#endif // TVM_TL_OP_ATOMIC_ADD_H_ diff --git a/src/op/copy.cc b/src/op/copy.cc index 754dd7336..275af38ba 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -130,7 +130,7 @@ template static Array ReverseArray(Array array) { * @param vmap BufferMap used to resolve RegionOp buffers and ranges. */ Copy::Copy(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { @@ -169,7 +169,7 @@ Copy::Copy(Array args, BufferMap vmap) { * @return TileOperator A TileOperator owning the cloned CopyNode. */ TileOperator CopyNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); if (par_op_.defined()) { op->par_op_ = Downcast(par_op_->Clone()); } @@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = - pass_ctx->GetConfig(kDisableTMALower, false).value(); + pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, T.analyzer, T.buffer_oob); @@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = - pass_ctx->GetConfig(kDisableTMALower, false).value(); + pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, T.layout_map, analyzer); if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { @@ -1722,7 +1722,8 @@ Array TMADesc::EncodeCallArgs() const { * @param vmap Mapping from original buffer variables to actual Buffer objects. */ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = + tvm::ffi::make_object(); node->src = vmap[GetVarFromAccessPtr(args[0])]; node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->nhw_step = args[2]; @@ -1747,7 +1748,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. */ TileOperator Conv2DIm2ColOpNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return Conv2DIm2ColOp(op); } @@ -1973,9 +1974,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { CopyNode::RegisterReflection(); Conv2DIm2ColOpNode::RegisterReflection(); -}); +} } // namespace tl } // namespace tvm diff --git a/src/op/copy.h b/src/op/copy.h index 00d07f169..ef46b9edb 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -101,8 +101,7 @@ class CopyNode : public TileOperatorNode { }; uint8_t eviction_policy; // Policy for cache eviction - static constexpr const char *_type_key = "tl.Copy"; - TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -114,23 +113,6 @@ class CopyNode : public TileOperatorNode { .def_ro("coalesced_width", &CopyNode::coalesced_width); } - bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const { - return equal(src, other->src) && equal(dst, other->dst) && - equal(src_range, other->src_range) && - equal(dst_range, other->dst_range) && - equal(coalesced_width, other->coalesced_width); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(src); - hash_reduce(dst); - hash_reduce(src_range); - hash_reduce(dst_range); - hash_reduce(coalesced_width); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - /*! * \brief Lower the copy operator to a TIR statement. * \param T Arguments for lowering. @@ -291,7 +273,7 @@ class CopyNode : public TileOperatorNode { class Copy : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode); /*! * \brief Constructor. @@ -323,8 +305,8 @@ class Conv2DIm2ColOpNode : public TileOperatorNode { PrimExpr nhw_step; // Step size in NHW dimensions PrimExpr c_step; // Step size in channel dimension - static constexpr const char *_type_key = "tl.Conv2DIm2Col"; - TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode, + TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -338,26 +320,6 @@ class Conv2DIm2ColOpNode : public TileOperatorNode { .def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy); } - bool SEqualReduce(const Conv2DIm2ColOpNode *other, - SEqualReducer equal) const { - return equal(src, other->src) && equal(dst, other->dst) && - equal(stride, other->stride) && equal(padding, other->padding) && - equal(dilation, other->dilation) && equal(kernel, other->kernel) && - equal(eviction_policy, other->eviction_policy); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(src); - hash_reduce(dst); - hash_reduce(stride); - hash_reduce(padding); - hash_reduce(dilation); - hash_reduce(kernel); - hash_reduce(eviction_policy); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - /*! * \brief Lower to TIR statement. */ @@ -378,8 +340,8 @@ class Conv2DIm2ColOpNode : public TileOperatorNode { class Conv2DIm2ColOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator, - Conv2DIm2ColOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, + Conv2DIm2ColOpNode); TVM_DLL Conv2DIm2ColOp(Array args, BufferMap vmap); static const Op &Get(); }; @@ -387,4 +349,4 @@ class Conv2DIm2ColOp : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_COPY_H_ \ No newline at end of file +#endif // TVM_TL_OP_COPY_H_ diff --git a/src/op/fill.cc b/src/op/fill.cc index 8f0dec63b..055e64053 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -60,7 +60,7 @@ using namespace tir; * of bounds. */ Fill::Fill(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); if (args[0]->IsInstance()) { auto buffer_load = Downcast(args[0]); @@ -117,7 +117,7 @@ Fill::Fill(Array args, BufferMap vmap) { * @return TileOperator A TileOperator that owns the copied FillNode. */ TileOperator FillNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return Fill(op); } @@ -226,7 +226,7 @@ TIR_REGISTER_TL_OP(Fill, fill) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); } } // namespace tl } // namespace tvm \ No newline at end of file diff --git a/src/op/fill.h b/src/op/fill.h index 6d3840763..8f1dd9006 100644 --- a/src/op/fill.h +++ b/src/op/fill.h @@ -20,8 +20,7 @@ class FillNode : public TileOperatorNode { tir::Buffer dst; ///< Destination buffer to fill PrimExpr value; ///< Value to fill with Array region; ///< Region to fill within the buffer - static constexpr const char *_type_key = "tl.Fill"; - TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; @@ -35,19 +34,6 @@ class FillNode : public TileOperatorNode { .def_ro("region", &FillNode::region); } - bool SEqualReduce(const FillNode *other, SEqualReducer equal) const { - return equal(dst, other->dst) && equal(value, other->value) && - equal(region, other->region); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dst); - hash_reduce(value); - hash_reduce(region); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - TileOperator Clone() const; private: @@ -58,7 +44,7 @@ class FillNode : public TileOperatorNode { /// Wrapper class for fill operations class Fill : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); TVM_DLL Fill(Array args, BufferMap vmap); static const Op &Get(); }; @@ -66,4 +52,4 @@ class Fill : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_FILL_H_ \ No newline at end of file +#endif // TVM_TL_OP_FILL_H_ diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index def940b4b..84b18897b 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -33,7 +33,7 @@ using namespace tir; * Buffer. */ FinalizeReducerOp::FinalizeReducerOp(Array args, BufferMap vmap) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->reducer = vmap[GetVarFromAccessPtr(args[0])]; node->op = (ReducerOpType)*as_const_int(args[1]); data_ = std::move(node); @@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T, * @return TileOperator A TileOperator that contains a deep copy of this node. */ TileOperator FinalizeReducerOpNode::Clone() const { - auto node = make_object(*this); + auto node = tvm::ffi::make_object(*this); return TileOperator(node); } @@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); } } // namespace tl } // namespace tvm diff --git a/src/op/finalize_reducer.h b/src/op/finalize_reducer.h index d9a66d1b9..ef49ee194 100644 --- a/src/op/finalize_reducer.h +++ b/src/op/finalize_reducer.h @@ -27,8 +27,8 @@ class FinalizeReducerOpNode : public TileOperatorNode { tir::Buffer reducer; ReducerOpType op; - static constexpr const char *_type_key = "tl.FinalizeReducerOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp", + FinalizeReducerOpNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -37,18 +37,6 @@ class FinalizeReducerOpNode : public TileOperatorNode { .def_ro("op", &FinalizeReducerOpNode::op); } - bool SEqualReduce(const FinalizeReducerOpNode *other, - SEqualReducer equal) const { - return equal(reducer, other->reducer) && equal(op, other->op); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(reducer); - hash_reduce(op); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; @@ -58,8 +46,8 @@ class FinalizeReducerOpNode : public TileOperatorNode { class FinalizeReducerOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator, - FinalizeReducerOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator, + FinalizeReducerOpNode); TVM_DLL FinalizeReducerOp(Array args, BufferMap vmap); static const Op &Get(); }; @@ -67,4 +55,4 @@ class FinalizeReducerOp : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_FINALIZE_REDUCER_H_ \ No newline at end of file +#endif // TVM_TL_OP_FINALIZE_REDUCER_H_ diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 8912a7a33..e0077bb34 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -112,7 +112,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { * performed here. */ Gemm::Gemm(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->Aptr = args[0]; node->Bptr = args[1]; @@ -160,7 +160,7 @@ Gemm::Gemm(Array args, BufferMap vmap) { * @return TileOperator A Gemm operator that owns a copy of this node. */ TileOperator GemmNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return Gemm(op); } @@ -476,8 +476,8 @@ bool GemmNode::CheckWGMMA() const { */ static int GetArchInt(Target target) { int arch_int = 0; - auto s = target->GetAttr("arch"); - ICHECK(s.defined()); + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); std::string arch = s.value(); if (arch.rfind("sm_", 0) == 0) { arch_int = std::stoi(arch.substr(3)); @@ -874,7 +874,7 @@ TIR_REGISTER_TL_OP(Gemm, gemm) TVM_REGISTER_OP("tl.GemmWarpPolicy") .set_attr("TScriptPrinterName", "GemmWarpPolicy"); -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { GemmNode::RegisterReflection(); GemmWarpPolicyNode::RegisterReflection(); namespace refl = tvm::ffi::reflection; @@ -883,9 +883,8 @@ TVM_FFI_STATIC_INIT_BLOCK({ Target target, GemmInst gemm_inst) { policy->ComputeWarpPartition(M, N, block_size, target, gemm_inst); - return; }); -}); +} } // namespace tl } // namespace tvm diff --git a/src/op/gemm.h b/src/op/gemm.h index dd7e24011..66cf9e2e0 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -30,8 +30,7 @@ class GemmWarpPolicyNode : public Object { mutable int n_warp{0}; int policy_type; - static constexpr const char *_type_key = "tl.GemmWarpPolicy"; - TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -41,21 +40,6 @@ class GemmWarpPolicyNode : public Object { .def_ro("n_warp", &GemmWarpPolicyNode::n_warp); } - bool SEqualReduce(const GemmWarpPolicyNode *other, - SEqualReducer equal) const { - return equal(policy_type, other->policy_type) && - equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(policy_type); - hash_reduce(m_warp); - hash_reduce(n_warp); - } - - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - std::pair ComputeWarpPartition(int M, int N, int block_size, Target target, GemmInst gemm_inst) const; @@ -74,22 +58,23 @@ class GemmWarpPolicyNode : public Object { class GemmWarpPolicy : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef, + GemmWarpPolicyNode); explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->policy_type = (int)policy_type; data_ = std::move(node); } explicit GemmWarpPolicy(int policy_type) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->policy_type = policy_type; data_ = std::move(node); } explicit GemmWarpPolicy(int m_warp, int n_warp) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->m_warp = m_warp; node->n_warp = n_warp; node->policy_type = (int)GemmWarpPolicyType::kFree; @@ -116,9 +101,7 @@ class GemmNode : public TileOperatorNode { std::optional mbar; // mbar is optional, only used for TCGEN5MMA Array C_coords; mutable GemmWarpPolicy policy; - - static constexpr const char *_type_key = "tl.Gemm"; - TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -144,45 +127,6 @@ class GemmNode : public TileOperatorNode { .def_ro("policy", &GemmNode::policy); } - bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const { - return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) && - equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) && - equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) && - equal(trans_B, other->trans_B) && equal(M, other->M) && - equal(N, other->N) && equal(K, other->K) && - equal(stride_A, other->stride_A) && - equal(stride_B, other->stride_B) && - equal(offset_A, other->offset_A) && - equal(offset_B, other->offset_B) && - equal(clear_accum, other->clear_accum) && - equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && - equal(policy, other->policy); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(A); - hash_reduce(B); - hash_reduce(C); - hash_reduce(Aptr); - hash_reduce(Bptr); - hash_reduce(Cptr); - hash_reduce(trans_A); - hash_reduce(trans_B); - hash_reduce(M); - hash_reduce(N); - hash_reduce(K); - hash_reduce(stride_A); - hash_reduce(stride_B); - hash_reduce(offset_A); - hash_reduce(offset_B); - hash_reduce(clear_accum); - hash_reduce(kPack); - hash_reduce(wg_wait); - hash_reduce(policy); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; @@ -199,7 +143,7 @@ class GemmNode : public TileOperatorNode { class Gemm : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); TVM_DLL Gemm(Array args, BufferMap vmap); static const Op &Get(); }; @@ -207,4 +151,4 @@ class Gemm : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_GEMM_H_ \ No newline at end of file +#endif // TVM_TL_OP_GEMM_H_ diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 4e48389ee..3641cf0b1 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -11,8 +11,8 @@ #include #include +#include "../support/ffi_aliases.h" #include "../target/utils.h" -#include "tvm/ffi/string.h" namespace tvm { namespace tl { @@ -48,7 +48,7 @@ using namespace tir; * performed here. */ GemmPy::GemmPy(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->Aptr = args[0]; node->Bptr = args[1]; @@ -88,7 +88,7 @@ GemmPy::GemmPy(Array args, BufferMap vmap) { * @return TileOperator A Gemm operator that owns a copy of this node. */ TileOperator GemmPyNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return GemmPy(op); } @@ -208,8 +208,8 @@ bool GemmPyNode::CheckWGMMA() const { */ static int GetArchInt(Target target) { int arch_int = 0; - auto s = target->GetAttr("arch"); - ICHECK(s.defined()); + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); std::string arch = s.value(); if (arch.rfind("sm_", 0) == 0) { arch_int = std::stoi(arch.substr(3)); @@ -228,11 +228,12 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { auto prim_func = - Downcast((*f)(GetRef(this), T.layout_map, T.target, - T.thread_bounds, T.thread_var)); + Downcast((*f)(tvm::ffi::GetRef(this), T.layout_map, + T.target, T.thread_bounds, T.thread_var)); ICHECK(prim_func->attrs.defined()); - auto global_symbol = prim_func->attrs.GetAttr("global_symbol"); - ICHECK(global_symbol.defined()); + auto global_symbol = + prim_func->attrs.GetAttr("global_symbol"); + ICHECK(global_symbol.has_value()); if (prim_func->body.as()) { BlockRealize block_realize = Downcast(prim_func->body); auto block = block_realize->block; @@ -265,7 +266,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) { results = Downcast( - (*f)(GetRef(this), T.target, T.thread_bounds)); + (*f)(tvm::ffi::GetRef(this), T.target, T.thread_bounds)); } else { LOG(FATAL) << "No infer layout function found for gemm_py"; } @@ -279,15 +280,15 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.GemmPyGemmInst", [](GemmPy gemm_py, int block_size, Target target) { return gemm_py->GetGemmInst(block_size, target); }); -}); +} } // namespace tl } // namespace tvm diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index 65ed08c0f..499efb6d9 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -33,8 +33,7 @@ class GemmPyNode : public TileOperatorNode { int wg_wait = 0; mutable GemmWarpPolicy policy; - static constexpr const char *_type_key = "tl.GemmPy"; - TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -60,45 +59,6 @@ class GemmPyNode : public TileOperatorNode { .def_ro("policy", &GemmPyNode::policy); } - bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const { - return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) && - equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) && - equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) && - equal(trans_B, other->trans_B) && equal(M, other->M) && - equal(N, other->N) && equal(K, other->K) && - equal(stride_A, other->stride_A) && - equal(stride_B, other->stride_B) && - equal(offset_A, other->offset_B) && - equal(offset_B, other->offset_B) && - equal(clear_accum, other->clear_accum) && - equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && - equal(policy, other->policy); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(A); - hash_reduce(B); - hash_reduce(C); - hash_reduce(Aptr); - hash_reduce(Bptr); - hash_reduce(Cptr); - hash_reduce(trans_A); - hash_reduce(trans_B); - hash_reduce(M); - hash_reduce(N); - hash_reduce(K); - hash_reduce(stride_A); - hash_reduce(stride_B); - hash_reduce(offset_A); - hash_reduce(offset_B); - hash_reduce(clear_accum); - hash_reduce(kPack); - hash_reduce(wg_wait); - hash_reduce(policy); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; @@ -114,7 +74,7 @@ class GemmPyNode : public TileOperatorNode { class GemmPy : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); TVM_DLL GemmPy(Array args, BufferMap vmap); static const Op &Get(); }; @@ -122,4 +82,4 @@ class GemmPy : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_GEMM_PY_H_ \ No newline at end of file +#endif // TVM_TL_OP_GEMM_PY_H_ diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index dfa58b353..a23d9a552 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -84,7 +84,7 @@ std::pair GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N, * @note An ICHECK failure is raised if a provided kPack is not 1 or 2. */ GemmSP::GemmSP(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->A = vmap[GetVarFromAccessPtr(args[0])]; node->E = vmap[GetVarFromAccessPtr(args[1])]; node->B = vmap[GetVarFromAccessPtr(args[2])]; @@ -118,7 +118,7 @@ GemmSP::GemmSP(Array args, BufferMap vmap) { * @return TileOperator A TileOperator holding a cloned GemmSPNode. */ TileOperator GemmSPNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return GemmSP(op); } @@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); } } // namespace tl } // namespace tvm diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index eee7cd795..4c6d1e25a 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -21,27 +21,29 @@ class GemmSPWarpPolicyNode : public GemmWarpPolicyNode { std::pair ComputeWarpPartition(int M, int N, int block_size, Target target, bool use_wgmma, int bits) const; + TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, + GemmWarpPolicyNode); }; class GemmSPWarpPolicy : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef, - GemmSPWarpPolicyNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef, + GemmSPWarpPolicyNode); explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->policy_type = (int)policy_type; data_ = std::move(node); } explicit GemmSPWarpPolicy(int policy_type) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->policy_type = policy_type; data_ = std::move(node); } explicit GemmSPWarpPolicy(int m_warp, int n_warp) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); node->m_warp = m_warp; node->n_warp = n_warp; node->policy_type = (int)GemmWarpPolicyType::kFree; @@ -62,8 +64,7 @@ class GemmSPNode : public TileOperatorNode { mutable GemmSPWarpPolicy policy; - static constexpr const char *_type_key = "tl.GemmSP"; - TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; @@ -88,38 +89,13 @@ class GemmSPNode : public TileOperatorNode { .def_ro("wg_wait", &GemmSPNode::wg_wait); } - bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const { - return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) && - equal(E, other->E) && equal(trans_A, other->trans_A) && - equal(trans_B, other->trans_B) && equal(M, other->M) && - equal(N, other->N) && equal(K, other->K) && - equal(clear_accum, other->clear_accum) && - equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(policy); - hash_reduce(A); - hash_reduce(B); - hash_reduce(C); - hash_reduce(E); - hash_reduce(trans_A); - hash_reduce(trans_B); - hash_reduce(M); - hash_reduce(N); - hash_reduce(K); - hash_reduce(clear_accum); - hash_reduce(kPack); - hash_reduce(wg_wait); - } - private: mutable bool completed_ = false; }; class GemmSP : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); TVM_DLL GemmSP(Array args, BufferMap vmap); static const Op &Get(); }; diff --git a/src/op/logical.cc b/src/op/logical.cc index 0398c38c1..0de6658bd 100644 --- a/src/op/logical.cc +++ b/src/op/logical.cc @@ -9,6 +9,8 @@ #include #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace tl { using namespace tir; @@ -50,4 +52,4 @@ TVM_REGISTER_OP("tl.all_of") .set_attr("cuda.FLowerIntrinsic", all_of_op); } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/op/math.cc b/src/op/math.cc index 572399877..526ea557c 100644 --- a/src/op/math.cc +++ b/src/op/math.cc @@ -9,6 +9,8 @@ #include #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace tl { using namespace tir; diff --git a/src/op/operator.cc b/src/op/operator.cc index aa589460b..b751559c7 100644 --- a/src/op/operator.cc +++ b/src/op/operator.cc @@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) { TileOperator ParseOperator(Stmt stmt, BufferMap vmap) { if (stmt.as() && stmt.as()->value.as()) { auto call = stmt.as()->value.as(); - return ParseOperator(GetRef(call), vmap); + return ParseOperator(tvm::ffi::GetRef(call), vmap); } return TileOperator(); } @@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) { ICHECK(call->op.same_as(builtin::tvm_access_ptr())); auto var = call->args[1].as(); ICHECK(var); - return GetRef(var); + return tvm::ffi::GetRef(var); } } // namespace tl diff --git a/src/op/operator.h b/src/op/operator.h index 5c1b223ac..e3a70dae2 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -62,14 +62,13 @@ class TileOperatorNode : public Object { virtual TileOperator Clone() const = 0; - static constexpr const char *_type_key = "tl.TileOperator"; - - TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object); }; class TileOperator : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef, + TileOperatorNode); }; Var GetVarFromAccessPtr(const PrimExpr &expr); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index c0ef00cc8..118a9e74b 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) { } TileOperator ParallelOpNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return ParallelOp(op); } @@ -642,7 +642,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { ->CondenseReplicateVar(); } -TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); } } // namespace tl } // namespace tvm diff --git a/src/op/parallel.h b/src/op/parallel.h index 9c6b7180f..8ebd7366e 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -66,8 +66,8 @@ class ParallelOpNode : public TileOperatorNode { mutable Optional predicate_; // Type key for TVM object system. - static constexpr const char *_type_key = "tl.ParallelOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode, + TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -77,20 +77,6 @@ class ParallelOpNode : public TileOperatorNode { .def_ro("predicate", &ParallelOpNode::predicate_); } - bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const { - return equal(root_, other->root_) && - equal(loop_layout_, other->loop_layout_) && - equal(predicate_, other->predicate_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(root_); - hash_reduce(loop_layout_); - hash_reduce(predicate_); - } - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - // Construct from a root For loop. ParallelOpNode(For root); @@ -150,10 +136,11 @@ class ParallelOpNode : public TileOperatorNode { class ParallelOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator, + ParallelOpNode); ParallelOp(const For &root) { - auto op = make_object(root); + auto op = tvm::ffi::make_object(root); data_ = std::move(op); } }; diff --git a/src/op/reduce.cc b/src/op/reduce.cc index fe49e00b6..3e31aa2f1 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -22,7 +22,7 @@ namespace tl { using namespace tir; ReduceOp::ReduceOp(Array args, BufferMap vmap) { - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->src = vmap[GetVarFromAccessPtr(args[0])]; node->dst = vmap[GetVarFromAccessPtr(args[1])]; std::string reduce_type = args[2].as().value()->value; @@ -33,12 +33,12 @@ ReduceOp::ReduceOp(Array args, BufferMap vmap) { } TileOperator ReduceOpNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return ReduceOp(op); } TileOperator CumSumOpNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return CumSumOp(op); } @@ -85,6 +85,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const { return make_zero(dst->dtype); } else { LOG(FATAL) << "Unsupported reduce type: " << type->type; + return PrimExpr(); } } @@ -512,7 +513,7 @@ CumSumOp::CumSumOp(Array args, BufferMap vmap) { /// - dim: dimension to cumsum /// - reverse: whether to cumsum in reverse order CHECK_EQ(args.size(), 4); - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->src = vmap[GetVarFromAccessPtr(args[0])]; node->dst = vmap[GetVarFromAccessPtr(args[1])]; node->dim = args[2].as().value()->value; @@ -567,5 +568,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + ReduceOpNode::RegisterReflection(); + CumSumOpNode::RegisterReflection(); + ReduceTypeNode::RegisterReflection(); +} + } // namespace tl } // namespace tvm diff --git a/src/op/reduce.h b/src/op/reduce.h index 853d6e0dd..93eb4bdec 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -30,23 +30,13 @@ enum class ReduceTypeEnum : uint8_t { class ReduceTypeNode : public Object { public: int type{-1}; ///< Internal type identifier - static constexpr const char *_type_key = "tl.ReduceType"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReduceTypeNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceType", ReduceTypeNode, Object); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("type", &ReduceTypeNode::type); } - bool SEqualReduce(const ReduceTypeNode *other, SEqualReducer equal) const { - return equal(type, other->type); - } - - void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(type); } - - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - /// Type checking methods bool isSum() const { return type == int(ReduceTypeEnum::kSum); } bool isAbsSum() const { return type == int(ReduceTypeEnum::kAbsSum); } @@ -61,9 +51,10 @@ class ReduceTypeNode : public Object { /// Wrapper class for reduction type with string-based construction class ReduceType : public ObjectRef { public: - TVM_DEFINE_OBJECT_REF_METHODS(ReduceType, ObjectRef, ReduceTypeNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceType, ObjectRef, + ReduceTypeNode); TVM_DLL ReduceType(std::string type) { - auto node = make_object(); + auto node = tvm::ffi::make_object(); if (type == "sum") { node->type = int(ReduceTypeEnum::kSum); } else if (type == "abssum") { @@ -95,8 +86,8 @@ class ReduceOpNode : public TileOperatorNode { ReduceType type; ///< Type of reduction operation bool clear; ///< Whether to clear destination before reduction - static constexpr const char *_type_key = "tl.ReduceOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReduceOp", ReduceOpNode, + TileOperatorNode); static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -108,23 +99,6 @@ class ReduceOpNode : public TileOperatorNode { .def_ro("clear", &ReduceOpNode::clear); } - bool SEqualReduce(const ReduceOpNode *other, SEqualReducer equal) const { - return equal(src, other->src) && equal(dst, other->dst) && - equal(dim, other->dim) && equal(type, other->type) && - equal(clear, other->clear); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(src); - hash_reduce(dst); - hash_reduce(dim); - hash_reduce(type); - hash_reduce(clear); - } - - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; - /// Lower the operator to TIR statements Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; /// Infer memory layout for buffers @@ -145,7 +119,8 @@ class ReduceOpNode : public TileOperatorNode { /// Wrapper class for reduction operations class ReduceOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator, + ReduceOpNode); TVM_DLL ReduceOp(Array args, BufferMap vmap); static const Op &Get(); }; @@ -156,8 +131,17 @@ class CumSumOpNode : public TileOperatorNode { tir::Buffer src, dst; ///< Source and destination buffers int dim; ///< Dimension along which to compute cumulative sum bool reverse; ///< Whether to compute in reverse order - static constexpr const char *_type_key = "tl.CumSumOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode, + TileOperatorNode); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &CumSumOpNode::src) + .def_ro("dst", &CumSumOpNode::dst) + .def_ro("dim", &CumSumOpNode::dim) + .def_ro("reverse", &CumSumOpNode::reverse); + } Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, @@ -169,7 +153,8 @@ class CumSumOpNode : public TileOperatorNode { /// Wrapper class for cumulative sum operations class CumSumOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator, + CumSumOpNode); TVM_DLL CumSumOp(Array args, BufferMap vmap); static const Op &Get(); }; @@ -177,4 +162,4 @@ class CumSumOp : public TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_REDUCE_H_ \ No newline at end of file +#endif // TVM_TL_OP_REDUCE_H_ diff --git a/src/op/region.cc b/src/op/region.cc index 95a0b4295..e4984af13 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -44,7 +44,7 @@ RegionOp::RegionOp(Array args, BufferMap vmap) { PrimExpr extent = args[2 + i]; ranges.push_back(Range::FromMinExtent(min, extent)); } - ObjectPtr node = make_object(); + ObjectPtr node = tvm::ffi::make_object(); node->buffer_ = load->buffer; node->access_mask_ = static_cast(*as_const_int(args[1])); node->ranges_ = ranges; @@ -57,7 +57,7 @@ RegionOp::RegionOp(Array args, BufferMap vmap) { * @return TileOperator A new TileOperator that owns a copied RegionOpNode. */ TileOperator RegionOpNode::Clone() const { - auto op = make_object(*this); + auto op = tvm::ffi::make_object(*this); return RegionOp(op); } @@ -118,5 +118,7 @@ TIR_REGISTER_TL_OP(RegionOp, region) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); +TVM_FFI_STATIC_INIT_BLOCK() { RegionOpNode::RegisterReflection(); } + } // namespace tl } // namespace tvm diff --git a/src/op/region.h b/src/op/region.h index 2d3c9d8ec..e5c478bff 100644 --- a/src/op/region.h +++ b/src/op/region.h @@ -80,8 +80,8 @@ class RegionOpNode : public TileOperatorNode { Array ranges_; int access_mask_; - static constexpr const char *_type_key = "tl.RegionOp"; - TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode, + TileOperatorNode); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, @@ -101,25 +101,12 @@ class RegionOpNode : public TileOperatorNode { .def_ro("ranges", &RegionOpNode::ranges_) .def_ro("access_mask", &RegionOpNode::access_mask_); } - - bool SEqualReduce(const RegionOpNode *other, SEqualReducer equal) const { - return equal(buffer_, other->buffer_) && equal(ranges_, other->ranges_) && - equal(access_mask_, other->access_mask_); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer_); - hash_reduce(ranges_); - hash_reduce(access_mask_); - } - - static constexpr bool _type_has_method_sequal_reduce = true; - static constexpr bool _type_has_method_shash_reduce = true; }; class RegionOp : public TileOperator { public: - TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator, + RegionOpNode); TVM_DLL RegionOp(Array args, BufferMap vmap); static const Op &Get(); diff --git a/src/runtime/runtime.cc b/src/runtime/runtime.cc index 3ea89d666..a00786e25 100644 --- a/src/runtime/runtime.cc +++ b/src/runtime/runtime.cc @@ -89,7 +89,7 @@ struct TensorMapArgs { }; // set device api -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) { @@ -104,7 +104,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } *ret = static_cast(result); }); -}); +} struct TensorMapIm2ColArgs { CUtensorMap *map; @@ -180,7 +180,7 @@ struct TensorMapIm2ColArgs { } }; -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) { @@ -197,7 +197,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ } *ret = static_cast(result); }); -}); +} #endif // (CUDA_MAJOR_VERSION >= 12) diff --git a/src/support/ffi_aliases.h b/src/support/ffi_aliases.h new file mode 100644 index 000000000..cbc6fb027 --- /dev/null +++ b/src/support/ffi_aliases.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace tvm { +using ffi::Array; +using ffi::Function; +using ffi::Map; +using ffi::Optional; +using ffi::String; +} // namespace tvm diff --git a/src/target/codegen_cpp.cc b/src/target/codegen_cpp.cc index a2c52cad9..9accf5303 100644 --- a/src/target/codegen_cpp.cc +++ b/src/target/codegen_cpp.cc @@ -29,6 +29,7 @@ #include #include +#include "../support/ffi_aliases.h" #include "support/str_escape.h" #include "target/build_common.h" #include "target/source/codegen_params.h" @@ -54,8 +55,7 @@ void CodeGenTileLangCPP::Init(bool output_ssa, bool emit_asserts, } void CodeGenTileLangCPP::InitGlobalContext() { - decl_stream << "void* " << tvm::runtime::symbol::tvm_ffi_library_ctx - << " = NULL;\n"; + decl_stream << "void* " << ffi::symbol::tvm_ffi_library_ctx << " = NULL;\n"; } void CodeGenTileLangCPP::DefineModuleName() { @@ -256,8 +256,8 @@ void CodeGenTileLangCPP::AddFunction(const PrimFunc &f) { // reserve keywords ReserveKeywordsAsUnique(); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); diff --git a/src/target/codegen_cpp.h b/src/target/codegen_cpp.h index c3ce25a0a..25bb115c8 100644 --- a/src/target/codegen_cpp.h +++ b/src/target/codegen_cpp.h @@ -73,10 +73,10 @@ class CodeGenTileLangCPP : public CodeGenC { void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*) void VisitStmt_(const AllocateNode *op) final; // NOLINT(*) - void GenerateForwardFunctionDeclarations(String global_symbol, - const Array &arg_types, + void GenerateForwardFunctionDeclarations(ffi::String global_symbol, + const ffi::Array &arg_types, const Type &ret_type) override; - Array GetFunctionNames() { return function_names_; } + ffi::Array GetFunctionNames() { return function_names_; } private: /* \brief Internal structure to store information about function calls */ @@ -92,7 +92,7 @@ class CodeGenTileLangCPP : public CodeGenC { /* \brief mapping global packed func to the unique name */ std::unordered_map declared_globals_; /* \brief names of the functions declared in this module */ - Array function_names_; + ffi::Array function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; /*! \brief whether to emit forward function declarations in the resulting C diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index fc06cb99a..189faa29b 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -20,6 +20,7 @@ namespace tvm { namespace codegen { using namespace tvm::tl::codegen; +using namespace ffi; struct CUDAMath { std::string operator()(DataType t, std::string name) const { @@ -2069,8 +2070,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { "A_ptr, B_ptr, C_ptr>, but got " << op->args.size(); auto op_instance = Downcast(op->args[0]); - this->PrintCallExtern(GetType(GetRef(op)), op_instance->value, - op->args, true, os); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments args.size(); auto op_instance = Downcast(op->args[0]); enable_sparse_gemm_ = true; - this->PrintCallExtern(GetType(GetRef(op)), op_instance->value, - op->args, true, os); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::get_lane_idx())) { ICHECK_LE(op->args.size(), 1) << "tl.get_lane_idx expects at most one argument ."; @@ -2352,8 +2353,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { int lanes = static_cast(Downcast(op->lanes)->value); - CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef(op) << " with " - << lanes << " lanes is not allowed."; + CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef(op) + << " with " << lanes << " lanes is not allowed."; os << "(make_"; PrintType(op->dtype, os); os << "("; @@ -2865,7 +2866,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar, ReserveKeywordsAsUnique(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index d4e8121b3..66a03bc0e 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -60,14 +60,14 @@ class CodeGenTileLangCUDA final : public CodeGenC { // Override this as a work around for __grid_constant__ parameter void AddFunction(const GlobalVar &gvar, const PrimFunc &f); - void PrintFunctionSignature(const String &function_name, const PrimFunc &func, - std::ostream &os); + void PrintFunctionSignature(const ffi::String &function_name, + const PrimFunc &func, std::ostream &os); protected: virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, PrimExpr index) final; - void PrintCallExtern(Type ret_type, String global_symbol, - const Array &args, bool skip_first_arg, + void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, std::ostream &os) final; // NOLINT(*) private: diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index 9c145750d..2cfb7a594 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -959,8 +959,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { "A_ptr, B_ptr, C_ptr>, but got " << op->args.size(); auto op_instance = Downcast(op->args[0]); - this->PrintCallExtern(GetType(GetRef(op)), op_instance->value, - op->args, true, os); + this->PrintCallExtern(GetType(tvm::ffi::GetRef(op)), + op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_gemm_sp())) { LOG(FATAL) << "tl_gemm_sp is not supported on HIP"; } else if (op->op.same_as(tl::loop_break())) { @@ -1309,7 +1309,7 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) { ReserveKeywordsAsUnique(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) + ICHECK(global_symbol.has_value()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); diff --git a/src/target/codegen_hip.h b/src/target/codegen_hip.h index 491040be3..631050feb 100644 --- a/src/target/codegen_hip.h +++ b/src/target/codegen_hip.h @@ -56,8 +56,8 @@ class CodeGenTileLangHIP final : public CodeGenC { protected: virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, PrimExpr index) final; - void PrintCallExtern(Type ret_type, String global_symbol, - const Array &args, bool skip_first_arg, + void PrintCallExtern(Type ret_type, ffi::String global_symbol, + const ffi::Array &args, bool skip_first_arg, std::ostream &os) final; // NOLINT(*) private: diff --git a/src/target/codegen_webgpu.cc b/src/target/codegen_webgpu.cc deleted file mode 100644 index 1d64ccbc6..000000000 --- a/src/target/codegen_webgpu.cc +++ /dev/null @@ -1,786 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file codegen_webgpu.cc - */ -#include "codegen_webgpu.h" -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "arith/pattern_match.h" -#include "runtime/meta_data.h" -#include "runtime/thread_storage_scope.h" -#include "target/build_common.h" - -namespace tvm { -namespace codegen { - -// WebGPU Info -struct WebGPUWorkGroupInfo { - int workgroup_size[3] = {1, 1, 1}; - // whether we have ref to block index z is used. - bool has_block_index_z{false}; - // set of handles that have write access - std::unordered_set write_access_set; -}; - -class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { -public: - static WebGPUWorkGroupInfo Collect(const Stmt &stmt) { - WebGPUWorkgroupInfoCollector collector; - collector(stmt); - return collector.info_; - } - -private: - void VisitExpr_(const VarNode *op) final { - StmtExprVisitor::VisitExpr_(op); - Var buffer_var = GetRef(op); - if (buffer_var.dtype().is_handle()) { - info_.write_access_set.insert(buffer_var); - } - } - - void VisitStmt_(const BufferStoreNode *op) final { - StmtExprVisitor::VisitStmt_(op); - info_.write_access_set.insert(op->buffer->data); - } - - void VisitStmt_(const AttrStmtNode *op) final { - // record workgroup size - if (op->attr_key == tir::attr::thread_extent) { - IterVar iv = Downcast(op->node); - if (!iv->thread_tag.empty()) { - runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); - if (ts.rank == 1) { - ICHECK_GE(ts.dim_index, 0) - << "vthread should have been optimized out by here"; - ICHECK_LT(ts.dim_index, 3); - auto *sizeptr = op->value.as(); - ICHECK(sizeptr) << "CodeGenTileLangWebGPU: only allows constant " - "thread group size " - << " get " << op->value; - info_.workgroup_size[ts.dim_index] = - static_cast(sizeptr->value); - } else if (ts.rank == 0) { - if (ts.dim_index == 2) { - info_.has_block_index_z = true; - } - } - } - } - // normal operation - StmtExprVisitor::VisitStmt_(op); - } - WebGPUWorkGroupInfo info_; -}; - -std::string CodeGenTileLangWebGPU::Finish() { - // Using f16 requires enable directive - if (enable_fp16_) { - header_stream << "enable f16;\n\n"; - } - // WebGPU WGSL doesn't support #include. - // We must explicitly include all the templates here. - return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + - stream.str(); -} - -void CodeGenTileLangWebGPU::InitFuncState(const PrimFunc &f) { - CodeGenC::InitFuncState(f); - // analyze the data; - for (Var arg : f->params) { - if (arg.dtype().is_handle()) { - alloc_storage_scope_[arg.get()] = "global"; - } - } -} - -CodeGenTileLangWebGPU::CodeGenTileLangWebGPU(Target target) : target_(target) {} - -runtime::FunctionInfo -CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) { - // clear previous generated state. - this->InitFuncState(f); - // reserve keywords - name_supply_->ReserveName("var"); - name_supply_->ReserveName("let"); - name_supply_->ReserveName("const"); - - // skip the first underscore, so SSA variable starts from - name_supply_->FreshName("v_"); - // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); - ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim"); - - // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc " - "to have the global_symbol attribute"; - - header_stream << "//----------------------------------------\n" - << "// Function: " << global_symbol.value() << "\n" - << "//----------------------------------------\n"; - runtime::FunctionInfo func_info; - func_info.name = global_symbol.value(); - - WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body); - - std::vector pod_args; - int num_buffer = 0; - - // add param_access modes info to launch params - std::ostringstream os_param_access; - os_param_access << "paramWriteAccess:["; - // setup buffer argumemts - for (Var arg : f->params) { - DataType t = arg.dtype(); - func_info.arg_types.push_back(t); - - if (t.is_handle()) { - auto *ptr = arg->type_annotation.as(); - ICHECK(ptr) << "All handles passed to the CodeGenTileLangWebGPU must " - "have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; - auto *prim = ptr->element_type.as(); - ICHECK(prim) << "All handles passed to the CodeGenTileLangWebGPU must " - "have a type_annotation as a " - "PointerType, " - << "and must point to a PrimType"; - DataType value_storage_type = prim->dtype; - if (value_storage_type == DataType::Bool()) { - // We need a physically addressable buffer type to support boolean - // tensors. The loaded byte is cast to bool inside the LoadNode visitor - // below. - value_storage_type = - boolean_storage_type_.with_lanes(value_storage_type.lanes()); - } - std::string vid = AllocVarID(arg.get()); - std::string access_mode; - if (num_buffer != 0) { - os_param_access << ","; - } - if (skip_readonly_decl || info.write_access_set.count(arg)) { - access_mode = "read_write"; - os_param_access << "1"; - } else { - access_mode = "read"; - os_param_access << "0"; - } - // add extra access mode info to launch params - this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " - << "var " << vid - << " : array<"; - this->PrintType(value_storage_type, this->decl_stream); - this->decl_stream << ">;\n"; - } else { - pod_args.push_back(arg); - } - } - - // Store all pod arguments in a single buffer of int32 - // do bitcast to change to other data types - // always pass gridDimX in to get around of the 65535 gridDim - // restrictions in some platforms - std::string type_pod_args = name_supply_->FreshName("PODArgs"); - std::string val_pod_args = name_supply_->FreshName("podArgs"); - std::string packGridDimX = name_supply_->FreshName("packGridDimX"); - - this->decl_stream << "\nstruct " << type_pod_args << " {\n"; - - for (size_t i = 0; i < pod_args.size(); ++i) { - const Var &v = pod_args[i]; - ICHECK(!v.dtype().is_handle()); - std::string vid = AllocVarID(v.get()); - - if (v.dtype() == DataType::Int(32)) { - this->decl_stream << " " << vid << ": i32"; - } else if (v.dtype() == DataType::UInt(32)) { - this->decl_stream << " " << vid << ": u32"; - } else if (v.dtype() == DataType::Float(32)) { - this->decl_stream << " " << vid << ": f32"; - } else { - LOG(FATAL) << "Do not support pod argument type " << v.dtype(); - } - this->decl_stream << ",\n"; - // value ref - std::ostringstream vref; - vref << val_pod_args << "." << vid; - var_idmap_[v.get()] = vref.str(); - } - this->decl_stream << " " << packGridDimX << ": u32\n}\n"; - - this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") " - << "var " << val_pod_args << " : " << type_pod_args - << ";\n\n"; - - // setup thread tags and param access in launch param tags; - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { - for (const auto &thread_tag : opt.value()) { - func_info.launch_param_tags.push_back(thread_tag); - } - } - os_param_access << "]"; - func_info.launch_param_tags.push_back(os_param_access.str()); - - ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to " - "accommodate large blockIdx.x"; - // annotate workgroup - this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", " - << info.workgroup_size[1] << ", " << info.workgroup_size[2] - << ")\n"; - - // add to alloc buffer type. - // Function header. - this->stream << "fn " << func_info.name << "(\n" - << " @builtin(workgroup_id) blockIdx : vec3,\n" - << " @builtin(num_workgroups) gridDim : vec3,\n" - << " @builtin(local_invocation_id) threadIdx : vec3\n" - << ") {\n"; - // skip out of bound grids - this->stream << " if (blockIdx.z * gridDim.x + blockIdx.x > " // NOLINT(*) - << val_pod_args << "." << packGridDimX << ") { return; }\n"; - // the function scope. - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; - return func_info; -} - -void CodeGenTileLangWebGPU::BindThreadIndex(const IterVar &iv) { - ICHECK(!var_idmap_.count(iv->var.get())); - std::ostringstream os; - PrintType(iv->var.dtype(), os); - if (iv->thread_tag == "blockIdx.x") { - // WebGPU have restriction to limit the maximum size of blockId.x to be - // 65535 We allow runtime to spread the load out to blockIdx.z so it can be - // a large number. - os << "(blockIdx.z * gridDim.x + blockIdx.x)"; - std::string tidx = os.str(); - std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype()); - var_idmap_[iv->var.get()] = aggregated_bidx; - } else { - os << "(" << iv->thread_tag << ")"; - std::string tidx = os.str(); - this->MarkConst(tidx); - var_idmap_[iv->var.get()] = tidx; - } -} - -void CodeGenTileLangWebGPU::PrintType(DataType t, - std::ostream &os) { // NOLINT(*) - int lanes = t.lanes(); - if (t.is_handle()) { - LOG(FATAL) << "Cannot print handle type in WebGPU"; - } - if (t.is_void()) { - os << "void"; - return; - } - if (t == DataType::Bool()) { - os << "bool"; - return; - } - - if (lanes != 1) { - // ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenTileLangWebGPU: only allows - // vector with lanes in {2, 3, 4} " << " while lanes is " << lanes; - os << "vec" << lanes << "<"; - } - - if (t.is_float()) { - ICHECK(t.bits() == 16 || t.bits() == 32) - << "CodeGenTileLangWebGPU: only support f16 or f32"; - if (t.bits() == 16) { - // Using f16 requires enable directive - enable_fp16_ = true; - } - os << "f" << t.bits(); - } else if (t.is_uint()) { - ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support u64"; - os << "u" << t.bits(); - } else if (t.is_int()) { - ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support i64"; - os << "i" << t.bits(); - } else { - LOG(FATAL) << "CodeGenTileLangWebGPU: Cannot convert type " << t - << " to WebGPU type"; - } - if (lanes != 1) { - os << ">"; - } -} - -void CodeGenTileLangWebGPU::PrintStorageSync(const CallNode *op) { - const std::string &sync = op->args[0].as()->value; - if (sync == "warp") { - this->PrintIndent(); - this->stream << "workgroupBarrier();\n"; - } else if (sync == "shared") { - this->PrintIndent(); - this->stream << "workgroupBarrier();\n"; - } else if (sync == "global") { - LOG(FATAL) << "global barrier not supported"; - } -} - -void CodeGenTileLangWebGPU::PrintSSAAssign(const std::string &target, - const std::string &src, - DataType type) { - stream << "let " << target << " : "; - PrintType(type, stream); - stream << " = " << src << ";\n"; -} - -void CodeGenTileLangWebGPU::VisitExpr_(const BroadcastNode *op, - std::ostream &os) { // NOLINT(*) - std::string v = PrintExpr(op->value); - int lanes = op->dtype.lanes(); - PrintType(op->dtype, os); - os << "("; - for (int i = 0; i < lanes; ++i) { - if (i != 0) - os << ", "; - os << v; - } - os << ')'; -} - -PrimExpr CodeGenTileLangWebGPU::EnforceU32(PrimExpr value) { - return cast(DataType::UInt(32, value.dtype().lanes()), value); -} - -void CodeGenTileLangWebGPU::VisitExpr_(const CallNode *op, - std::ostream &os) { // NOLINT(*) - if (op->op.same_as(builtin::reinterpret())) { - // generate bitcast(ARG) - os << "bitcast<"; - this->PrintType(op->dtype, os); - os << ">("; - this->PrintExpr(op->args[0], os); - os << ")"; - } else if (op->op.same_as(builtin::shift_right())) { - os << '('; - this->PrintExpr(op->args[0], os); - os << ">>"; - // WebGPU requires shift bits to be u32. - this->PrintExpr(EnforceU32(op->args[1]), os); - os << ')'; - } else if (op->op.same_as(builtin::shift_left())) { - os << '('; - this->PrintExpr(op->args[0], os); - os << "<<"; - // WebGPU requires shift bits to be u32. - this->PrintExpr(EnforceU32(op->args[1]), os); - os << ')'; - } else if (op->op.same_as(builtin::if_then_else())) { - // conditional that skips eval if cond evals to false - std::string result = name_supply_->FreshName("condval"); - std::string cond = PrintExpr(op->args[0]); - this->PrintIndent(); - this->stream << "var " << result << " : "; - PrintType(op->dtype, this->stream); - this->stream << ";\n"; - this->PrintIndent(); - this->stream << "if (" << cond << ") {\n"; - { - int then_scope = this->BeginScope(); - std::string true_val = PrintExpr(op->args[1]); - this->PrintIndent(); - this->stream << result << " = " << true_val << ";\n} else {\n"; - this->EndScope(then_scope); - } - { - int else_scope = this->BeginScope(); - std::string false_val = PrintExpr(op->args[2]); - this->PrintIndent(); - this->stream << result << " = " << false_val << ";\n}\n"; - this->EndScope(else_scope); - } - os << result; - } else { - CodeGenC::VisitExpr_(op, os); - } -} - -void CodeGenTileLangWebGPU::VisitExpr_(const CastNode *op, - std::ostream &os) { // NOLINT(*) - PrintType(op->dtype, os); - os << "(" << PrintExpr(op->value) << ")"; -} - -void CodeGenTileLangWebGPU::VisitExpr_(const SelectNode *op, - std::ostream &os) { // NOLINT(*) - os << "select(" << PrintExpr(op->false_value) << ", " - << PrintExpr(op->true_value) << ", " << PrintExpr(op->condition) << ")"; -} - -void CodeGenTileLangWebGPU::VisitExpr_(const IntImmNode *op, - std::ostream &os) { // NOLINT(*) - if (op->dtype.bits() == 32) { - std::ostringstream temp; - if (op->dtype.is_int()) { - temp << op->value << "i"; - } else { - ICHECK(op->dtype.is_uint()); - temp << op->value << "u"; - } - this->MarkConst(temp.str()); - os << temp.str(); - } else { - this->PrintType(op->dtype, os); - os << "(" << op->value << ")"; - } -} - -void CodeGenTileLangWebGPU::VisitExpr_(const FloatImmNode *op, - std::ostream &os) { // NOLINT(*) - std::ostringstream temp; - temp << std::scientific << op->value; - if (op->dtype.bits() == 32) { - temp << 'f'; - } else if (op->dtype.bits() == 16) { - // Using f16 requires enable directive - enable_fp16_ = true; - temp << 'h'; - } else { - LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits(); - } - MarkConst(temp.str()); - os << temp.str(); -} - -void CodeGenTileLangWebGPU::VisitExpr_(const BufferLoadNode *op, - std::ostream &os) { // NOLINT(*) - // NOTE: direct impl of load/store for correctness - // Each printing stmt must stand on their own after all preprocessing steps - // to ensure correctness in the case of nested-expression - // do not try to lift common printings from each case - ICHECK_EQ(op->indices.size(), 1) - << "Load from non-flat memory not supported."; - - DataType value_dtype = op->dtype; - PrimExpr index = op->indices[0]; - Var buffer_var = op->buffer->data; - DataType element_dtype = op->buffer->dtype; - - int lanes = op->dtype.lanes(); - std::string buffer_vid = GetVarID(buffer_var.get()); - - if (value_dtype.lanes() == element_dtype.lanes()) { - // Direct buffer loading - // Special handle bool loading - if (value_dtype == DataType::Bool()) { - this->PrintType(value_dtype, os); - os << "("; - } else { - ICHECK(value_dtype == element_dtype); - } - ICHECK_EQ(index.dtype().lanes(), 1); - os << buffer_vid << "[" << this->PrintExpr(index) << "]"; - // Special handle bool loading - if (value_dtype == DataType::Bool()) { - os << ")"; - } - } else { - // Vector load from scalar buffer - ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - ICHECK(value_dtype.element_of() == element_dtype) - << "WebGPU vector loading requires base type to match"; - arith::PVar base; - if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { - // vec3(buf[base + 0], buf[base + 1], buf[base + 2]); - std::string base_vid = - SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); - os << "("; - for (int i = 0; i < lanes; ++i) { - if (i != 0) - os << ", "; - os << buffer_vid << "[" << base_vid << " + " << i << "]"; - } - os << ")"; - } else { - // vec3(buf[index[0]], buf[index[1]], buf[index[2]]); - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - PrintType(element_dtype.with_lanes(value_dtype.lanes()), os); - os << "("; - for (int i = 0; i < lanes; ++i) { - if (i != 0) - os << ", "; - os << buffer_vid << "[" << index_vid << "[" << i << "]]"; - } - os << ")"; - } - } -} - -void CodeGenTileLangWebGPU::VisitStmt_(const LetStmtNode *op) { - // use ssa form. - if (print_ssa_form_) { - std::string value = PrintExpr(op->value); - ICHECK(!var_idmap_.count(op->var.get())); - var_idmap_[op->var.get()] = value; - } else { - PrintIndent(); - std::string value = PrintExpr(op->value); - this->stream << "let " << AllocVarID(op->var.get()) << " : "; - PrintType(op->var.dtype(), this->stream); - this->stream << " = " << value << ";\n"; - } - PrintStmt(op->body); -} - -void CodeGenTileLangWebGPU::VisitStmt_(const BufferStoreNode *op) { - CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported."; - DataType value_dtype = op->value.dtype(); - DataType element_dtype = op->buffer->dtype; - PrimExpr index = op->indices[0]; - Var buffer_var = op->buffer->data; - - std::string buffer_vid = GetVarID(buffer_var.get()); - - if (value_dtype.lanes() == element_dtype.lanes()) { - // must execute print expr first - // so we won't have recursive append to stream - std::string index_vid = PrintExpr(index); - std::string value_vid = PrintExpr(op->value); - // now print the assignment line. - this->PrintIndent(); - stream << buffer_vid << "[" << index_vid << "] = "; - // special explicit conversion of bool - if (value_dtype == DataType::Bool()) { - PrintType(element_dtype, stream); - stream << "("; - } else { - ICHECK(value_dtype == element_dtype); - } - stream << value_vid; - // Special handle bool store - if (value_dtype == DataType::Bool()) { - stream << ")"; - } - stream << ";\n"; - } else { - // Vector store into scalar buffer - ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array"; - ICHECK(value_dtype.element_of() == element_dtype) - << "WebGPU vector stire requires base type to match"; - std::string value_vid = PrintExpr(op->value); - arith::PVar base; - if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) { - // buf[base + 0] = value[0] - // buf[base + 1] = value[1] - std::string base_vid = - SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { - this->PrintIndent(); - stream << buffer_vid << "[" << base_vid << " + " << i - << "] = " << value_vid << "[" << i << "];\n"; - } - } else { - // buf[index[0]] = value[0] - // buf[index[1]] = value[1] - std::string index_vid = SSAGetID(PrintExpr(index), index.dtype()); - for (int i = 0; i < value_dtype.lanes(); ++i) { - this->PrintIndent(); - stream << buffer_vid << "[" << index_vid << "[" << i - << "]] = " << value_vid << "[" << i << "];\n"; - } - } - } -} - -void CodeGenTileLangWebGPU::VisitStmt_(const AllocateNode *op) { - ICHECK(!is_zero(op->condition)); - std::string vid = AllocVarID(op->buffer_var.get()); - size_t constant_size = op->ConstantAllocationSize(); - ICHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - auto storage_scope = - runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - - if (storage_scope.rank == runtime::StorageRank::kShared) { - this->decl_stream << "var " << vid << " : array<"; - PrintType(op->dtype, this->decl_stream); - this->decl_stream << ", " << constant_size << ">;\n"; - } else if (storage_scope.rank == runtime::StorageRank::kLocal) { - // TODO(Charlie): These code would cause non-uniformity as it introduces - // variables in module scope rather than function scope; but it was included - // for some unknown reasons; kept for now. this->decl_stream << - // "var " << vid << " : array<"; PrintType(op->dtype, - // this->decl_stream); this->decl_stream << ", " << constant_size << ">;\n"; - this->PrintIndent(); - this->stream << "var " << vid << " : array<"; - PrintType(op->dtype, this->stream); - this->stream << ", " << constant_size << ">;\n"; - } else { - LOG(FATAL) << "WebGPU: Do not support storage scope: " - << storage_scope.to_string(); - } - this->PrintStmt(op->body); -} - -void CodeGenTileLangWebGPU::VisitStmt_(const ForNode *op) { - std::string extent = PrintExpr(op->extent); - std::string vid = AllocVarID(op->loop_var.get()); - ICHECK(is_zero(op->min)); - PrintIndent(); - stream << "for (var " << vid << " : "; - PrintType(op->loop_var.dtype(), stream); - stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n"; - int for_scope = BeginScope(); - PrintStmt(op->body); - this->EndScope(for_scope); - PrintIndent(); - stream << "}\n"; -} - -void CodeGenTileLangWebGPU::VisitStmt_(const AssertStmtNode *op) { - // skip assert - PrintStmt(op->body); -} - -void CodeGenTileLangWebGPU::VisitStmt_(const AllocateConstNode *op) { - LOG(FATAL) << "WebGPU: do not support alloc const"; -} - -void CodeGenTileLangWebGPU::VisitStmt_(const WhileNode *op) { - PrintIndent(); - stream << "while (true) {\n"; - int while_scope = BeginScope(); - std::string cond = PrintExpr(op->condition); - PrintIndent(); - stream << "if (!(" << cond << ")) { break; }\n"; - PrintStmt(op->body); - this->EndScope(while_scope); - PrintIndent(); - stream << "}\n"; -} - -//------------------------------------------------- -// WebGPUSourceModule to enable export -//------------------------------------------------- -class WebGPUSourceModuleNode final : public runtime::ModuleNode { -public: - explicit WebGPUSourceModuleNode( - std::unordered_map smap, - std::unordered_map fmap) - : smap_(smap), fmap_(fmap) {} - - const char *type_key() const final { return "webgpu"; } - /*! \brief Get the property of the runtime module .*/ - int GetPropertyMask() const final { - return runtime::ModulePropertyMask::kBinarySerializable; - } - - ffi::Function GetFunction(const String &name, - const ObjectPtr &sptr_to_self) final { - LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run " - "through tvmjs"; - return ffi::Function(nullptr); - } - - void SaveToBinary(dmlc::Stream *stream) final { - stream->Write(fmap_); - stream->Write(smap_); - } - - String GetSource(const String &format) final { - if (format == "func_info") { - std::ostringstream stream; - dmlc::JSONWriter(&stream).Write(fmap_); - return stream.str(); - } else { - std::ostringstream os; - for (const auto &kv : smap_) { - os << kv.second; - } - return os.str(); - } - } - -private: - // function shader code table. - std::unordered_map smap_; - // function information table. - std::unordered_map fmap_; -}; - -//------------------------------------------------- -// Build logic. -//------------------------------------------------- -runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) { - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); - bool output_ssa = false; - bool skip_readonly_decl = false; - std::unordered_map smap; - std::unordered_map fmap; - - // narrow all i64 to i32 - mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod)); - - for (auto kv : mod->functions) { - CodeGenTileLangWebGPU cg(target); - ICHECK(kv.second->IsInstance()) - << "CodeGenTileLangWebGPU: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenTileLangWebGPU: expect calling_conv equals " - "CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc " - "to have the global_symbol attribute"; - std::string f_name = global_symbol.value(); - cg.Init(output_ssa); - fmap[f_name] = cg.AddFunction(f, skip_readonly_decl); - std::string code = cg.Finish(); - smap[f_name] = code; - } - - auto n = make_object(smap, fmap); - return runtime::Module(n); -} - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("target.build.tilelang_webgpu", - [](IRModule mod, Target target) { - return BuildTileLangWebGPU(mod, target); - }); -}); - -} // namespace codegen -} // namespace tvm diff --git a/src/target/codegen_webgpu.h b/src/target/codegen_webgpu.h deleted file mode 100644 index fa2da8895..000000000 --- a/src/target/codegen_webgpu.h +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file codegen_webgpu.h - * \brief Generate WebGPU shaders in WGSL. - * - * This module generates WGSL shading language. - * See https://www.w3.org/TR/WGSL/ for the language reference. - */ -#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ -#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ - -#include - -#include - -#include "target/source/codegen_c.h" - -namespace tvm { -namespace codegen { - -/*! - * \brief WebGPU code generator. - * - * Note WGSL have a different syntax from normal C. - * We only leverage the C for expression generation and - * write most of the language generations. - */ -class CodeGenTileLangWebGPU final : public CodeGenC { -public: - explicit CodeGenTileLangWebGPU(Target target); - // overrides - std::string Finish() final; - using CodeGenC::AddFunction; - runtime::FunctionInfo AddFunction(const PrimFunc &f, - bool skip_readonly_decl); // NOLINT(*) - void InitFuncState(const PrimFunc &f) final; - void PrintStorageSync(const CallNode *op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream &os) final; // NOLINT(*) - void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) - - // assignment printing - void PrintSSAAssign(const std::string &target, const std::string &src, - DataType type) final; - - // overload visitor - void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*) - void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*) - void VisitExpr_(const BufferLoadNode *op, - std::ostream &os) final; // NOLINT(*) - void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*) - void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*) - void VisitExpr_(const IntImmNode *op, std::ostream &os) final; // NOLINT(*) - - // stmt printing - void VisitStmt_(const LetStmtNode *op) final; - void VisitStmt_(const BufferStoreNode *op) final; - void VisitStmt_(const ForNode *op) final; - void VisitStmt_(const AllocateNode *op) final; - void VisitStmt_(const AssertStmtNode *op) final; - void VisitStmt_(const AllocateConstNode *op) final; - void VisitStmt_(const WhileNode *op) final; - -private: - /*! - * \brief Enforce value to be U32. - */ - static PrimExpr EnforceU32(PrimExpr value); - /*! - * \brief Storage type of bool values. - */ - DataType boolean_storage_type_{DataType::Int(8)}; - - // whether enable fp16 - bool enable_fp16_{false}; - - /*! \brief the header stream for function label and enable directive if any, - * goes before any other declaration */ - std::ostringstream header_stream; - - Target target_; -}; -} // namespace codegen -} // namespace tvm - -#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_ diff --git a/src/target/intrin_rule_cuda.cc b/src/target/intrin_rule_cuda.cc index 4ba3f10ab..1aacd7204 100644 --- a/src/target/intrin_rule_cuda.cc +++ b/src/target/intrin_rule_cuda.cc @@ -5,6 +5,7 @@ #include #include +#include "../support/ffi_aliases.h" #include "target/intrin_rule.h" namespace tvm { diff --git a/src/target/intrin_rule_hip.cc b/src/target/intrin_rule_hip.cc index 2bd3e2dd9..e142d8474 100644 --- a/src/target/intrin_rule_hip.cc +++ b/src/target/intrin_rule_hip.cc @@ -5,6 +5,7 @@ #include #include +#include "../support/ffi_aliases.h" #include "target/intrin_rule.h" namespace tvm { @@ -286,4 +287,4 @@ TVM_REGISTER_OP("tir.hip.__activemask") } // namespace intrin } // namespace codegen -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/target/rt_mod_cpp.cc b/src/target/rt_mod_cpp.cc index a7f2e62b9..10e3d57b6 100644 --- a/src/target/rt_mod_cpp.cc +++ b/src/target/rt_mod_cpp.cc @@ -1,10 +1,13 @@ #include "codegen_cpp.h" +#include #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace codegen { -runtime::Module BuildCPPHost(IRModule mod, Target target) { +ffi::Module BuildCPPHost(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; bool emit_fwd_func_decl = true; @@ -67,10 +70,10 @@ runtime::Module BuildCPPHost(IRModule mod, Target target) { return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.tilelang_cpp", BuildCPPHost); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/rt_mod_cuda.cc b/src/target/rt_mod_cuda.cc index 63a9f020b..bb69170fe 100644 --- a/src/target/rt_mod_cuda.cc +++ b/src/target/rt_mod_cuda.cc @@ -26,18 +26,19 @@ ExtractFuncInfo(const IRModule &mod) { } info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { for (const auto &tag : opt.value()) { info.launch_param_tags.push_back(tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } return fmap; } -runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { +ffi::Module BuildTileLangCUDA(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangCUDA cg; cg.Init(output_ssa); @@ -70,7 +71,7 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { +ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangCUDA cg; cg.Init(output_ssa); @@ -93,13 +94,13 @@ runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) { return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.tilelang_cuda", BuildTileLangCUDA) .def("target.build.tilelang_cuda_without_compile", BuildTileLangCUDAWithoutCompile); -}); +} } // namespace codegen } // namespace tvm diff --git a/src/target/rt_mod_hip.cc b/src/target/rt_mod_hip.cc index d0041f570..50991d631 100644 --- a/src/target/rt_mod_hip.cc +++ b/src/target/rt_mod_hip.cc @@ -37,18 +37,19 @@ ExtractFuncInfo(const IRModule &mod) { } info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>( + tir::attr::kKernelLaunchParams)) { for (const auto &tag : opt.value()) { info.launch_param_tags.push_back(tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } return fmap; } -runtime::Module BuildTileLangHIP(IRModule mod, Target target) { +ffi::Module BuildTileLangHIP(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); @@ -84,7 +85,7 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) { return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); } -runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { +ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); @@ -110,13 +111,13 @@ runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { std::string()); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.tilelang_hip", BuildTileLangHIP) .def("target.build.tilelang_hip_without_compile", BuildTileLangHIPWithoutCompile); -}); +} } // namespace codegen -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/target/utils.cc b/src/target/utils.cc index ca4f8570b..b69e3dd4c 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -5,6 +5,9 @@ #include "utils.h" +#include "../support/ffi_aliases.h" +#include + namespace tvm { namespace tl { @@ -16,8 +19,8 @@ bool TargetIsRocm(Target target) { } int GetArchInt(Target target) { - auto s = target->GetAttr("arch"); - ICHECK(s.defined()); + auto s = target->GetAttr("arch"); + ICHECK(s.has_value()); const std::string arch_str = s.value(); ICHECK(arch_str.size() >= 3); ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0) @@ -71,7 +74,7 @@ bool TargetIsCDNA(Target target) { if (!TargetIsRocm(target)) return false; if (target->attrs.count("mcpu")) { - std::string mcpu = Downcast(target->attrs.at("mcpu")); + std::string mcpu = Downcast(target->attrs.at("mcpu")); // if mcpu start with "gfx9", it is CDNA return mcpu.find("gfx9") == 0; } @@ -84,7 +87,7 @@ bool TargetHasAsyncCopy(Target target) { return arch >= 80; } else if (TargetIsCDNA(target)) { if (target->attrs.count("mcpu")) { - std::string mcpu = Downcast(target->attrs.at("mcpu")); + std::string mcpu = Downcast(target->attrs.at("mcpu")); if (mcpu.rfind("gfx9", 0) == 0) { int gfx_version = std::stoi(mcpu.substr(3, 2)); return gfx_version >= 94; @@ -131,7 +134,7 @@ int TargetGetWarpSize(Target target) { return res; } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("tl.TargetIsCuda", @@ -160,7 +163,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ [](Target target) { return TargetHasBulkCopy(target); }) .def("tl.TargetGetWarpSize", [](Target target) { return TargetGetWarpSize(target); }); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/align_dynamic_shared_memory_allocations.cc b/src/transform/align_dynamic_shared_memory_allocations.cc index 27890c445..1c2519df9 100644 --- a/src/transform/align_dynamic_shared_memory_allocations.cc +++ b/src/transform/align_dynamic_shared_memory_allocations.cc @@ -47,7 +47,7 @@ class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { } Stmt VisitStmt_(const BlockNode *op) final { - Block block = GetRef(op); + Block block = tvm::ffi::GetRef(op); Array alloc_buffers = op->alloc_buffers; alloc_buffers.MutateByApply([this](Buffer buf) { auto storage_scope = @@ -58,7 +58,7 @@ class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { buf->dtype.bytes()); if (!new_shape.same_as(buf->shape)) { ObjectPtr new_buffer = - make_object(*(buf.get())); + tvm::ffi::make_object(*(buf.get())); new_buffer->shape = std::move(new_shape); buffer_remap_.Set(buf, Buffer(new_buffer)); return Buffer(new_buffer); @@ -73,7 +73,7 @@ class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { } Stmt VisitStmt_(const BufferStoreNode *op) final { - auto store_node = GetRef(op); + auto store_node = tvm::ffi::GetRef(op); Buffer buf = op->buffer; if (buffer_remap_.count(buf)) { buf = buffer_remap_[buf]; @@ -83,7 +83,7 @@ class TileLangAlignDynamicSharedMemoryAllocations : public StmtExprMutator { } PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto load_node = GetRef(op); + auto load_node = tvm::ffi::GetRef(op); Buffer buf = op->buffer; if (buffer_remap_.count(buf)) { buf = buffer_remap_[buf]; @@ -149,11 +149,11 @@ tvm::transform::Pass AlignDynamicSharedMemoryAllocations(int align_bytes) { "tl.AlignDynamicSharedMemoryAllocations", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.AlignDynamicSharedMemoryAllocations", AlignDynamicSharedMemoryAllocations); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/annotate_device_regions.cc b/src/transform/annotate_device_regions.cc index ed57f3729..ecc0cba9d 100644 --- a/src/transform/annotate_device_regions.cc +++ b/src/transform/annotate_device_regions.cc @@ -46,13 +46,13 @@ class DeviceRegionAnnotater : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == tvm::attr::kTarget) { // If a target attribute already exists, use it as-is. - return GetRef(op); + return tvm::ffi::GetRef(op); } else if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::pipeline_exec_scope || op->attr_key == tir::attr::device_scope) { // These attributes are only allowed in device-side code, so // they should be annotated with the function's default target. - Stmt body = GetRef(op); + Stmt body = tvm::ffi::GetRef(op); return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); } else { // All other annotations are ignored @@ -90,11 +90,11 @@ tvm::transform::Pass AnnotateDeviceRegions() { return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateDeviceRegions", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index 6949c64e8..537c229a2 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -181,11 +181,11 @@ tvm::transform::Pass AnnotateWarpGroupRegAlloc() { return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc", AnnotateWarpGroupRegAlloc); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/arg_binder.cc b/src/transform/arg_binder.cc index 2caef2239..7df6d0cc8 100644 --- a/src/transform/arg_binder.cc +++ b/src/transform/arg_binder.cc @@ -80,8 +80,8 @@ void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value, Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array &arg, - const Array &value, +void ArgBinder::BindArray(const ffi::Array &arg, + const ffi::Array &value, const std::string &arg_name) { ICHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; @@ -250,7 +250,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); PrimExpr expect_stride = make_const(stype, 1); - Array conds; + ffi::Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; PrimExpr svalue = diff --git a/src/transform/arg_binder.h b/src/transform/arg_binder.h index d2dcc06aa..d04e7e9b2 100644 --- a/src/transform/arg_binder.h +++ b/src/transform/arg_binder.h @@ -82,7 +82,8 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array &arg, const Array &value, + void BindArray(const ffi::Array &arg, + const ffi::Array &value, const std::string &arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer @@ -149,7 +150,7 @@ class ArgBinder { */ const std::vector &init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map &def_handle_dtype() const { + const ffi::Map &def_handle_dtype() const { return def_handle_dtype_; } @@ -164,7 +165,7 @@ class ArgBinder { /*! \brief Initialize nest */ std::vector init_nest_; /*! \brief handle data type in the defintiions */ - Map def_handle_dtype_; + ffi::Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; /*! \brief internal analyzer. */ diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index cd63c9583..40cb81402 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -249,7 +249,6 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { new_args.push_back(dst_node); new_args.push_back(value_node); } - new_args.push_back(memory_order); Call new_call = @@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) { } } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/transform/cluster_planning.cc b/src/transform/cluster_planning.cc index e847bb2b6..7fcdc1691 100644 --- a/src/transform/cluster_planning.cc +++ b/src/transform/cluster_planning.cc @@ -10,6 +10,8 @@ #include #include +#include "../support/ffi_aliases.h" + namespace tvm { namespace tir { @@ -66,7 +68,8 @@ class ClusterPlanner { } if (mem_reuse_max > 0) { - std::string tag_str = cluster_tag; // Convert to std::string + std::string tag_str = + static_cast(cluster_tag); // Convert to std::string if (tag_str.rfind("blockIdx", 0) == 0) { // starts with "blockIdx" tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx")); @@ -74,7 +77,7 @@ class ClusterPlanner { // Unexpected format — maybe just prefix tag_str = "clusterIdx" + tag_str; } - cluster_tag = tvm::ffi::String(tag_str); // Convert back + cluster_tag = String(tag_str); // Convert back return WithAttr(f, cluster_tag, Integer(cluster_size_)); } else { return f; @@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() { return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning); -}); +} } // namespace transform } // namespace tir diff --git a/src/transform/common/loop_parallel_transform_utils.h b/src/transform/common/loop_parallel_transform_utils.h index b5a1ccddc..1e8d7a350 100644 --- a/src/transform/common/loop_parallel_transform_utils.h +++ b/src/transform/common/loop_parallel_transform_utils.h @@ -41,7 +41,7 @@ class ParallelLoopTransformer : public IRMutatorWithAnalyzer { return StmtMutator::VisitStmt_(op); // Collect loop variables and ranges - auto for_node = GetRef(op); + auto for_node = tvm::ffi::GetRef(op); Array loop_vars; Array loop_extents; Stmt body = op->body; @@ -81,7 +81,7 @@ class ParallelLoopTransformer : public IRMutatorWithAnalyzer { // post order visit the index PostOrderVisit(index, [&](const ObjectRef &obj) { if (const VarNode *v = obj.as()) { - used_vars.insert(GetRef(v)); + used_vars.insert(tvm::ffi::GetRef(v)); } }); if (used_vars.empty()) { diff --git a/src/transform/common/loop_vectorization_utils.h b/src/transform/common/loop_vectorization_utils.h index 3f033c966..b9b7715d0 100644 --- a/src/transform/common/loop_vectorization_utils.h +++ b/src/transform/common/loop_vectorization_utils.h @@ -211,7 +211,7 @@ class Vectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); @@ -265,7 +265,7 @@ class Vectorizer : public StmtMutator, PrimExpr VisitExpr_(const NotNode *op) final { PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return !(a); } @@ -306,10 +306,10 @@ class Vectorizer : public StmtMutator, PrimExpr value = this->VisitExpr(op->value); if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } if (value.same_as(op->value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Broadcast(op->value, op->lanes); } @@ -321,7 +321,7 @@ class Vectorizer : public StmtMutator, PrimExpr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor(); @@ -339,7 +339,7 @@ class Vectorizer : public StmtMutator, PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { if (value.dtype().is_scalable_vector()) { return Cast(op->dtype.with_scalable_vscale_factor( @@ -352,20 +352,20 @@ class Vectorizer : public StmtMutator, } PrimExpr VisitExpr_(const FloatImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const IntImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const StringImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } // Variable PrimExpr VisitExpr_(const VarNode *op) final { - Var var = GetRef(op); + Var var = tvm::ffi::GetRef(op); if (var.same_as(var_)) { return ramp_; @@ -382,13 +382,13 @@ class Vectorizer : public StmtMutator, PrimExpr cond = this->VisitExpr(op->args[0]); if (cond.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor(); @@ -410,7 +410,7 @@ class Vectorizer : public StmtMutator, ICHECK(op->op.same_as(builtin::reinterpret())); PrimExpr value = this->VisitExpr(op->args[0]); if (value.same_as(op->args[0])) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int lanes = value.dtype().get_lanes_or_vscale_factor(); if (value.dtype().is_scalable_vector()) { @@ -455,12 +455,12 @@ class Vectorizer : public StmtMutator, auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Call(op->dtype, op->op, new_args); } @@ -469,7 +469,7 @@ class Vectorizer : public StmtMutator, Array new_args = MutateArray(op->args, &lane); // normal code path. if (op->args.same_as(new_args)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Call(op->dtype.with_lanes(lane), op->op, new_args); } @@ -477,7 +477,7 @@ class Vectorizer : public StmtMutator, } // BufferLoad PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto load = GetRef(op); + auto load = tvm::ffi::GetRef(op); auto fmutate = [this](const PrimExpr &index) { return this->VisitExpr(index); @@ -514,7 +514,7 @@ class Vectorizer : public StmtMutator, let_binding_[op->var] = op->var; PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -522,7 +522,7 @@ class Vectorizer : public StmtMutator, } // BufferStore Stmt VisitStmt_(const BufferStoreNode *op) final { - auto store = GetRef(op); + auto store = tvm::ffi::GetRef(op); auto fmutate = [this](const PrimExpr &index) { return this->VisitExpr(index); @@ -585,11 +585,11 @@ class Vectorizer : public StmtMutator, ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, op->annotations); @@ -600,7 +600,7 @@ class Vectorizer : public StmtMutator, ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); if (condition.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = std::nullopt; @@ -609,7 +609,7 @@ class Vectorizer : public StmtMutator, } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -634,7 +634,7 @@ class Vectorizer : public StmtMutator, let_binding_[op->var] = op->var; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -647,7 +647,7 @@ class Vectorizer : public StmtMutator, if (condition.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } // Mutate the extents @@ -657,7 +657,7 @@ class Vectorizer : public StmtMutator, if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } extents.push_back(new_ext); } @@ -738,7 +738,7 @@ class Vectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -754,7 +754,7 @@ class Vectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); diff --git a/src/transform/config_index_bitwidth.cc b/src/transform/config_index_bitwidth.cc index 58ca0da7f..b0a577555 100644 --- a/src/transform/config_index_bitwidth.cc +++ b/src/transform/config_index_bitwidth.cc @@ -38,7 +38,7 @@ class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter { if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) { return IntImm(DataType::Int(_index_bitwidth_), op->value); } - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const CastNode *op) final { @@ -88,23 +88,23 @@ class IndexLegalizer : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const VarNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { - return cast(DataType::Int(64), GetRef(op)); + return cast(DataType::Int(64), tvm::ffi::GetRef(op)); } - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const IntImmNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { return IntImm(DataType::Int(64), op->value); } - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const CastNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { return cast(DataType::Int(64), op->value); } - return GetRef(op); + return tvm::ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode *op) final { @@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() { return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth", ConfigIndexBitwidth); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/eliminate_storage_sync_for_mbarrier.cc b/src/transform/eliminate_storage_sync_for_mbarrier.cc index cc187e8e2..504de732c 100644 --- a/src/transform/eliminate_storage_sync_for_mbarrier.cc +++ b/src/transform/eliminate_storage_sync_for_mbarrier.cc @@ -35,9 +35,7 @@ class Eliminator : public IRMutatorWithAnalyzer { Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == "thread_extent") { - const VarNode *var = nullptr; - if (op->node->IsInstance()) { - var = op->node.as(); + if (const auto *var = op->node.as()) { if (var->name_hint == "threadIdx.x") { thread_extent_ = op; } @@ -82,7 +80,7 @@ class Eliminator : public IRMutatorWithAnalyzer { } Stmt VisitStmt_(const ForNode *op) final { - PostOrderVisit(GetRef(op), [&](const ObjectRef &node) { + PostOrderVisit(tvm::ffi::GetRef(op), [&](const ObjectRef &node) { if (const auto *call = node.as()) { if (call->op.same_as(create_list_of_mbarrier()) || call->op.same_as(mbarrier_wait_parity()) || @@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() { {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier", EliminateStorageSyncForMBarrier); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/flatten_buffer.cc b/src/transform/flatten_buffer.cc index 4affa5f6e..3b68d3373 100644 --- a/src/transform/flatten_buffer.cc +++ b/src/transform/flatten_buffer.cc @@ -75,23 +75,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const VarNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { - return cast(DataType::Int(64), GetRef(op)); + return cast(DataType::Int(64), tvm::ffi::GetRef(op)); } - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const IntImmNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { return IntImm(DataType::Int(64), op->value); } - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const CastNode *op) final { if (op->dtype.is_int() && op->dtype.bits() < 64) { return cast(DataType::Int(64), op->value); } - return GetRef(op); + return tvm::ffi::GetRef(op); } Stmt VisitStmt_(const BufferStoreNode *op) final { @@ -115,7 +115,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { << "All MatchBufferRegion should be removed in " "tir.transform.LowerMatchBuffer."; - Block block = GetRef(op); + Block block = tvm::ffi::GetRef(op); Array alloc_buffers = op->alloc_buffers; alloc_buffers.MutateByApply( @@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() { return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/frontend_legalize.cc b/src/transform/frontend_legalize.cc index b366d02d1..ffb4b1a53 100644 --- a/src/transform/frontend_legalize.cc +++ b/src/transform/frontend_legalize.cc @@ -89,10 +89,10 @@ Pass LetInline() { return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LetInline", LetInline); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/if_stmt_binding.cc b/src/transform/if_stmt_binding.cc index 5eb8c1181..5da796c9d 100644 --- a/src/transform/if_stmt_binding.cc +++ b/src/transform/if_stmt_binding.cc @@ -33,7 +33,7 @@ class IfStmtBindingRewriter : public StmtExprMutator { auto then_case = VisitStmt(op->then_case); Optional else_case = op->else_case; if (else_case.defined()) { - return GetRef(op); + return tvm::ffi::GetRef(op); } ICHECK(then_case.defined()) << "then_case must be defined"; ICHECK(!else_case.defined()) << "else_case must be undefined"; @@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() { return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/inject_assumes.cc b/src/transform/inject_assumes.cc index d4c8a53c8..485e270c3 100644 --- a/src/transform/inject_assumes.cc +++ b/src/transform/inject_assumes.cc @@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() { return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes); -}); +} } // namespace tvm::tl diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index ee76dfac1..f425d4a9e 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -319,10 +319,10 @@ tvm::transform::Pass InjectFenceProxy() { {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 20f0861e2..3bb13611d 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -37,7 +37,7 @@ namespace tvm { namespace tl { using namespace tir; - +using namespace ffi; namespace software_pipeline { /*! @@ -459,7 +459,8 @@ class PipelineRewriter : public StmtExprMutator { * \return The resized buffer. */ Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { - ObjectPtr new_buffer = make_object(*(buffer.get())); + ObjectPtr new_buffer = + tvm::ffi::make_object(*(buffer.get())); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (!new_buffer->strides.empty()) { ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); @@ -865,7 +866,7 @@ class PipelineInjector : private StmtExprMutator { const SeqStmtNode *pipeline_body_seq = nullptr; std::vector> rewrap_fns; auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) { - ObjectRef node = attr->node; + Any node = attr->node; String attr_key = attr->attr_key; PrimExpr value = attr->value; Span span = attr->span; @@ -981,7 +982,7 @@ class PipelineInjector : private StmtExprMutator { // Step 4: Rewrite the pipeline body. Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, - GetRef(op), pipeline_info) + tvm::ffi::GetRef(op), pipeline_info) .BuildPipeline(); auto apply_wrappers = [&](Stmt stmt) { for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { @@ -1072,11 +1073,11 @@ tir::transform::Pass InjectSoftwarePipeline() { return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline", InjectSoftwarePipeline); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/inject_ptx_async_copy.cc b/src/transform/inject_ptx_async_copy.cc index 5b3ad4226..1fadefbf4 100644 --- a/src/transform/inject_ptx_async_copy.cc +++ b/src/transform/inject_ptx_async_copy.cc @@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() { return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index 39c6debda..aad1f474b 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -204,9 +204,9 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer { void VisitStmt_(const EvaluateNode *op) final { if (const auto *call = op->value.as()) { if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { - pending_tma_ops_.push_back(GetRef(call)); + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); } else if (call->op.same_as(mbarrier_expect_tx())) { - pending_tma_ops_.push_back(GetRef(call)); + pending_tma_ops_.push_back(tvm::ffi::GetRef(call)); } else if (call->op.same_as(builtin::ptx_arrive_barrier())) { PrimExpr barrier_id = call->args[0]; for (const auto &tma_call : pending_tma_ops_) { @@ -295,8 +295,9 @@ class TmaSequenceCollector : public IRVisitorWithAnalyzer { void VisitExpr_(const CallNode *op) final { if (op->op.same_as(mbarrier_expect_tx())) { - PrimExpr e = - tma_op_to_barrier_id_[GetRef(op)].as()->args[0]; + PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)] + .as() + ->args[0]; auto int_set = arith::EvalSet(e, var_int_set_); expect_.push_back(if_depth_ == 1); sequence.push_back(0); @@ -406,7 +407,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { private: Stmt VisitStmt_(const BlockNode *op) { - auto block = GetRef(op); + auto block = tvm::ffi::GetRef(op); if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() && op->name_hint == MainBlockName) { ICHECK(false) << "Please declare create_list_of_mbarrier."; @@ -453,9 +454,9 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode *op) { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { // check this must be in the tma_op_to_barrier_id_ - ICHECK(tma_op_to_barrier_id_.count(GetRef(op))) + ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef(op))) << "tma_load must be in the tma_op_to_barrier_id_"; - auto barrier_id = tma_op_to_barrier_id_[GetRef(op)]; + auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)]; auto new_args = op->args; auto arg0 = op->args[0].as(); auto is_1d_tma_load = @@ -468,9 +469,9 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { } return Call(op->dtype, op->op, new_args); } else if (op->op.same_as(mbarrier_expect_tx())) { - ICHECK(tma_op_to_barrier_id_.count(GetRef(op))) + ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef(op))) << "mbarrier_expect_tx must be in the tma_op_to_barrier_id_"; - auto barrier_id = tma_op_to_barrier_id_[GetRef(op)]; + auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef(op)]; auto new_args = op->args; new_args.Set(0, barrier_id); if (!has_warp_specialization_) @@ -522,10 +523,10 @@ tvm::transform::Pass InjectTmaBarrier() { return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index c3e552538..282a8d7de 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -330,7 +330,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { if (op->op.as()) return; - auto p = ParseOperator(GetRef(op), buffer_data_to_buffer_); + auto p = ParseOperator(tvm::ffi::GetRef(op), buffer_data_to_buffer_); if (p.defined()) { for (const auto &arg : op->args) { if (auto buffer = getBufferFromAccessPtr(arg)) { @@ -381,7 +381,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } // Add the tile operator to infer_list_ - infer_list_stmt_.push_back(GetRef(op)); + infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); infer_list_.push_back(std::move(p)); } } @@ -416,11 +416,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { void VisitStmt_(const ForNode *op) final { if (op->kind == ForKind::kParallel) { - auto infer = ParallelOp(GetRef(op)); + auto infer = ParallelOp(tvm::ffi::GetRef(op)); for (const auto &[buffer, _] : infer->GetIndiceMap()) { addToUseList(buffer); } - infer_list_stmt_.push_back(GetRef(op)); + infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); infer_list_.push_back(std::move(infer)); thread_var_vec_.push_back(thread_var_); if (thread_var_.defined() && @@ -711,8 +711,8 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { .value(); For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - if (result_.for_map.count(GetRef(op))) { - auto root = GetRef(op); + if (result_.for_map.count(tvm::ffi::GetRef(op))) { + auto root = tvm::ffi::GetRef(op); // This check is a workaround to support T.Parallel for local buffers. // For example: // for i in T.Parallel(1024): @@ -831,10 +831,10 @@ tvm::transform::Pass LayoutInference() { return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/layout_reducer.cc b/src/transform/layout_reducer.cc index e875c972c..101e9f4a1 100644 --- a/src/transform/layout_reducer.cc +++ b/src/transform/layout_reducer.cc @@ -362,10 +362,10 @@ tvm::transform::Pass LayoutReducer() { return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/layout_reducer.h b/src/transform/layout_reducer.h index 894631cc2..e46ade948 100644 --- a/src/transform/layout_reducer.h +++ b/src/transform/layout_reducer.h @@ -66,17 +66,17 @@ struct ReducerInfoNode : Object { ReducerInfoNode() = default; ReducerInfoNode(const String &op_str, const String &rep_str); - static constexpr const char *_type_key = "tl.ReducerInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object); }; struct ReducerInfo : ObjectRef { public: TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) { - data_ = make_object(op_str, rep_str); + data_ = tvm::ffi::make_object(op_str, rep_str); } - TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef, + ReducerInfoNode); }; namespace attr { diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index ee408d4a5..68a0cdbb8 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -38,7 +38,7 @@ class LeafForFinder : public StmtVisitor { StmtVisitor::VisitStmt(op->body); if (!has_child_for_) { - leaf_for_nodes.push_back(GetRef(op)); + leaf_for_nodes.push_back(tvm::ffi::GetRef(op)); } parent_has_child_for_ = parent_has_child_for; @@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { } // Register the pass globally so it can be used in the compilation pipeline -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess", LegalizeSafeMemoryAccess); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/legalize_vectorized_loop.cc b/src/transform/legalize_vectorized_loop.cc index dc2099208..aa461784a 100644 --- a/src/transform/legalize_vectorized_loop.cc +++ b/src/transform/legalize_vectorized_loop.cc @@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() { } // Register the pass globally so it can be used in the compilation pipeline -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop", LegalizeVectorizedLoop); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index e9930310a..fe1fe0366 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -173,7 +173,7 @@ class LoopPramaUnroller : public StmtExprMutator { if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) { return StmtExprMutator::VisitStmt_(node); } - For new_for = GetRef(node); + For new_for = tvm::ffi::GetRef(node); auto for_ptr = new_for.CopyOnWrite(); for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false)); for_ptr->kind = ForKind::kUnrolled; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 4550af8e4..45283d905 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } bool CanProveIndependent(const PrimExpr &expr, Var var, arith::Analyzer *analyzer) { // 1. if var doesn't exist, it is independent - bool used_var = UsesVar( - expr, [&](const VarNode *v) { return GetRef(v).same_as(var); }); + bool used_var = UsesVar(expr, [&](const VarNode *v) { + return tvm::ffi::GetRef(v).same_as(var); + }); if (!used_var) { return true; } diff --git a/src/transform/loop_vectorize_dynamic.cc b/src/transform/loop_vectorize_dynamic.cc index d02582726..c72af5a07 100644 --- a/src/transform/loop_vectorize_dynamic.cc +++ b/src/transform/loop_vectorize_dynamic.cc @@ -231,10 +231,10 @@ class VectorizedBodyMutator : public StmtExprMutator { if (flag) { return thenexpr; } else { - return GetRef(op); + return tvm::ffi::GetRef(op); } } else { - return GetRef(op); + return tvm::ffi::GetRef(op); } } @@ -535,11 +535,11 @@ tvm::transform::Pass LoopVectorizeDynamic() { } // Register the pass globally so it can be used in the compilation pipeline -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic", LoopVectorizeDynamic); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/lower_device_kernel_launch.cc b/src/transform/lower_device_kernel_launch.cc index 7ea7f7c62..f2d8ae239 100644 --- a/src/transform/lower_device_kernel_launch.cc +++ b/src/transform/lower_device_kernel_launch.cc @@ -36,7 +36,7 @@ namespace tvm { namespace tl { using namespace tir; - +using namespace ffi; namespace { struct KernelInfo { // The device on which the PrimFunc runs @@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto &[gvar, base_func] : mod->functions) { if (auto *ptr = base_func.as()) { - auto prim_func = - mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); + auto prim_func = mutator.RewriteKernelLaunchSite( + gvar, tvm::ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } @@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { IRModule updates; for (const auto &[gvar, base_func] : mod->functions) { if (auto *ptr = base_func.as()) { - auto prim_func = - mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); + auto prim_func = mutator.UpdateKernelAttributes( + gvar, tvm::ffi::GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } @@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() { "tl.LowerDeviceKernelLaunch", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/lower_device_storage_access_info.cc b/src/transform/lower_device_storage_access_info.cc index 635a3fdb8..1be06af27 100644 --- a/src/transform/lower_device_storage_access_info.cc +++ b/src/transform/lower_device_storage_access_info.cc @@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() { {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo", LowerDeviceStorageAccessInfo); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/lower_hopper_intrin.cc b/src/transform/lower_hopper_intrin.cc index 6e0da6993..b082a574e 100644 --- a/src/transform/lower_hopper_intrin.cc +++ b/src/transform/lower_hopper_intrin.cc @@ -113,14 +113,14 @@ class LowerHopperIntrin : public StmtExprMutator { if (call->op.same_as(create_tma_descriptor()) || call->op.same_as(create_tma_im2col_descriptor())) { Var var; - auto iter = desc_map_.find(GetRef(call)); + auto iter = desc_map_.find(tvm::ffi::GetRef(call)); if (iter != desc_map_.end()) { var = iter->second; } else { String name = call->args[2].as().value()->name_hint; var = Var(name + "_desc", PointerType(PrimType(cuTensorMapType()), "grid_constant")); - desc_map_[GetRef(call)] = var; + desc_map_[tvm::ffi::GetRef(call)] = var; prefetch_calls_.push_back( Evaluate(Call(DataType::Handle(), builtin::call_extern(), {StringImm("tl::prefetch_tma_descriptor"), var}))); @@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin); -}); +} #endif // (CUDA_MAJOR_VERSION >= 12) } // namespace tl diff --git a/src/transform/lower_intrin.cc b/src/transform/lower_intrin.cc index 737fc8936..edd0e1a18 100644 --- a/src/transform/lower_intrin.cc +++ b/src/transform/lower_intrin.cc @@ -37,6 +37,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: @@ -70,9 +71,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode *op) final { if (auto *ptr_op = op->op.as()) { for (const auto &f_attr_map : attr_maps_) { - FLowerGeneral f = f_attr_map.get(GetRef(ptr_op), nullptr); + FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef(ptr_op), nullptr); if (f != nullptr) { - PrimExpr e = GetRef(op); + PrimExpr e = tvm::ffi::GetRef(op); PrimExpr r = f(e); ICHECK(r.defined()) << "intrinsic rule must always return valid Expr"; if (!r.same_as(e)) { @@ -99,7 +100,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions PrimExpr VisitExpr_(const FloorDivNode *op) final { - auto e = GetRef(op); + auto e = tvm::ffi::GetRef(op); PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) @@ -305,7 +306,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using namespace arith; PVar x, y; PVar c; - auto e = GetRef(op); + auto e = tvm::ffi::GetRef(op); if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); @@ -316,7 +317,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const EQNode *op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = tvm::ffi::GetRef(op); if ((floormod(x, y) == 0).Match(e)) { return VisitExpr((truncmod(x, y) == 0).Eval()); } @@ -326,7 +327,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const NENode *op) final { using namespace arith; PVar x, y; - auto e = GetRef(op); + auto e = tvm::ffi::GetRef(op); if ((floormod(x, y) != 0).Match(e)) { return VisitExpr((truncmod(x, y) != 0).Eval()); } @@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin); -}); +} } // namespace transform diff --git a/src/transform/lower_l2_persistent_annotation.cc b/src/transform/lower_l2_persistent_annotation.cc index 8a8dee4c0..1f7be710d 100644 --- a/src/transform/lower_l2_persistent_annotation.cc +++ b/src/transform/lower_l2_persistent_annotation.cc @@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/lower_opaque_block.cc b/src/transform/lower_opaque_block.cc index b278fbf47..aa2e63850 100644 --- a/src/transform/lower_opaque_block.cc +++ b/src/transform/lower_opaque_block.cc @@ -151,7 +151,7 @@ class OpaqueBlockLower : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode *op) final { - Var var = GetRef(op); + Var var = tvm::ffi::GetRef(op); auto it = unit_loop_vars_.find(var); if (it == unit_loop_vars_.end()) { return var; @@ -286,10 +286,10 @@ tir::transform::Pass LowerOpaqueBlock() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/lower_shared_barrier.cc b/src/transform/lower_shared_barrier.cc index a3208d181..991676cb8 100644 --- a/src/transform/lower_shared_barrier.cc +++ b/src/transform/lower_shared_barrier.cc @@ -32,7 +32,7 @@ class SharedBarrierRewriter : public StmtExprMutator { : disable_shuffle_elect_(disable_shuffle_elect) {} Stmt VisitStmt_(const BlockNode *op) final { - Block block = GetRef(op); + Block block = tvm::ffi::GetRef(op); Array alloc_buffers = op->alloc_buffers; // Record the mapping from buffer data var to buffer for later lookup @@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/lower_shared_tmem.cc b/src/transform/lower_shared_tmem.cc index 661b39949..191ca700e 100644 --- a/src/transform/lower_shared_tmem.cc +++ b/src/transform/lower_shared_tmem.cc @@ -30,7 +30,7 @@ class SharedTmemRewriter : public StmtExprMutator { private: Stmt VisitStmt_(const BlockNode *op) final { - Block block = GetRef(op); + Block block = tvm::ffi::GetRef(op); Array alloc_buffers = op->alloc_buffers; if (op->annotations.count(attr::kLayoutMap)) { auto layout_map = op->annotations.Get(attr::kLayoutMap); @@ -300,10 +300,10 @@ tvm::transform::Pass LowerSharedTmem() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/lower_thread_allreduce.cc b/src/transform/lower_thread_allreduce.cc index 71ef8a92c..dc0fbeb85 100644 --- a/src/transform/lower_thread_allreduce.cc +++ b/src/transform/lower_thread_allreduce.cc @@ -39,6 +39,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; using runtime::StorageRank; using runtime::StorageScope; @@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerThreadAllreduce", LowerThreadAllreduce); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc old mode 100755 new mode 100644 index 09583f2c9..96ae34e3f --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -435,7 +435,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return expr; } if (const auto *var_node = expr.as()) { - Var var = GetRef(var_node); + Var var = tvm::ffi::GetRef(var_node); auto it = let_bindings_.find(var); if (it != let_bindings_.end()) { return it->second; @@ -611,7 +611,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { let_bindings_.erase(op->var); } if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = value; @@ -652,7 +652,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { if (call && call->op.as()) return Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - auto tile_op = ParseOperator(GetRef(op), buffer_data_to_buffer_); + auto tile_op = + ParseOperator(tvm::ffi::GetRef(op), buffer_data_to_buffer_); if (!tile_op.defined()) return IRMutatorWithAnalyzer::VisitStmt_(op); AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { @@ -730,10 +731,10 @@ tvm::transform::Pass LowerTileOp() { return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/make_packed_api.cc b/src/transform/make_packed_api.cc index b03193c8c..b0a67e6d5 100644 --- a/src/transform/make_packed_api.cc +++ b/src/transform/make_packed_api.cc @@ -42,6 +42,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; static constexpr const char *kDeviceContextVar = "device_api_context"; namespace { @@ -168,7 +169,7 @@ class SubroutineCallRewriter : public StmtExprMutator { auto node = Downcast(StmtExprMutator::VisitExpr_(op)); if (auto *gvar_ptr = node->op.as()) { - auto gvar = GetRef(gvar_ptr); + auto gvar = tvm::ffi::GetRef(gvar_ptr); if (auto symbol = packed_func_methods.Get(gvar)) { Array cpacked_args; cpacked_args.push_back(tir::StringImm(symbol.value())); @@ -220,7 +221,7 @@ Optional RequiresPackedAPI(const PrimFunc &func) { // Internal function calls do not need the PackedFunc API auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - if (!global_symbol.defined()) { + if (!global_symbol) { return std::nullopt; } @@ -229,7 +230,7 @@ Optional RequiresPackedAPI(const PrimFunc &func) { PrimFunc MakePackedAPI(PrimFunc func) { auto global_symbol = RequiresPackedAPI(func); - if (!global_symbol.defined()) { + if (!global_symbol) { return func; } std::string name_hint = global_symbol.value(); @@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { StringImm(name_hint + "_compute_"), body); // Set device context if (vmap.count(device_id.get())) { - ObjectRef node = String("default"); + auto node = String("default"); seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop)); seq_check.push_back( AttrStmt(node, tir::attr::device_type, device_type, nop)); @@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() { return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.MakePackedAPI", []() { return MakePackedAPI(); }); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/merge_if_stmt.cc b/src/transform/merge_if_stmt.cc index db0206e4c..39ea3b0b7 100644 --- a/src/transform/merge_if_stmt.cc +++ b/src/transform/merge_if_stmt.cc @@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() { return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index f558fdbc8..f2175efe0 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -162,7 +162,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { ICHECK_LT(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { // set into scope_.size() - 1 for aggressive memory reuse auto enable_aggressive_merge = enable_aggressive_merge_; if (enable_aggressive_merge) { @@ -209,7 +209,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { // the merged allocator can reason about their lifetime correctly. ICHECK_LE(it->second.level, scope_.size()) << "Load memory in places other than store."; - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { auto enable_aggressive_merge = enable_aggressive_merge_; if (enable_aggressive_merge) { scope_[scope_.size() - 1].touched.push_back(buf); @@ -233,7 +233,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { // emitted at the allocation level after flattening, so accept them and // record the touch for liveness planning. ICHECK_LE(it->second.level, scope_.size()); - if (IsAppropriateSharedMemory(GetRef(buf))) { + if (IsAppropriateSharedMemory(tvm::ffi::GetRef(buf))) { auto enable_aggressive_merge = enable_aggressive_merge_; if (enable_aggressive_merge) { scope_[scope_.size() - 1].touched.push_back(buf); @@ -372,7 +372,7 @@ class SharedMemoryAlignmentPlanner : public StmtExprVisitor { void VisitExpr_(const VarNode *op) { auto ptr_type = op->type_annotation.as(); if (ptr_type && under_alignment_scope_) { - auto scope = GetPtrStorageScope(GetRef(op)); + auto scope = GetPtrStorageScope(tvm::ffi::GetRef(op)); if (scope == "shared" || scope == "shared.dyn") { auto target = Target::Current(); ICHECK(target.defined()) << "Target is not defined"; @@ -1343,11 +1343,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false, {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations", MergeSharedMemoryAllocations); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index 38c9108c3..7ed9437cf 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -57,7 +57,7 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { // Check reads from global Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", - /*body*/ GetRef(op)); + /*body*/ tvm::ffi::GetRef(op)); auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); auto reads = access[0]; Role role = Role::kProducer; @@ -253,7 +253,8 @@ class MultiVersionBufferRewriter : public StmtExprMutator { } static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { - ObjectPtr new_buffer = make_object(*(buffer.get())); + ObjectPtr new_buffer = + tvm::ffi::make_object(*(buffer.get())); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); if (!new_buffer->strides.empty()) { ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); @@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() { return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/persist_threadblock.cc b/src/transform/persist_threadblock.cc index 56f0b4bd0..b64ffdcce 100644 --- a/src/transform/persist_threadblock.cc +++ b/src/transform/persist_threadblock.cc @@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() { return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index 15d4ff961..717dce27f 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -103,7 +103,7 @@ class AsyncDependencyChainBuilder : public StmtExprVisitor { ICHECK(call->op.same_as(builtin::tvm_access_ptr())); auto var = call->args[1].as(); ICHECK(var); - auto it = buffer_data_to_buffer_.find(GetRef(var)); + auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef(var)); ICHECK(it != buffer_data_to_buffer_.end()); return (*it).second; }; @@ -210,7 +210,7 @@ class BufferRegionCollector : public StmtExprVisitor { if (const auto *load = op->args[0].as()) { buffer_region = BufferRegion::FullRegion(load->buffer); } else if (const auto *var_node = op->args[0].as()) { - Var data_var = GetRef(var_node); + Var data_var = tvm::ffi::GetRef(var_node); auto it = buffer_data_to_buffer_.find(data_var); if (it != buffer_data_to_buffer_.end()) { buffer_region = BufferRegion::FullRegion((*it).second); @@ -223,7 +223,7 @@ class BufferRegionCollector : public StmtExprVisitor { } else if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode *buffer_var = op->args[1].as(); ICHECK(buffer_var); - auto it = buffer_data_to_buffer_.find(GetRef(buffer_var)); + auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)); if (it != buffer_data_to_buffer_.end()) { const Buffer &buffer = (*it).second; const BufferRegion buffer_region = BufferRegion::FullRegion(buffer); @@ -402,7 +402,7 @@ class PipelinePlanner : public StmtExprMutator { if (TargetHasAsyncCopy(target_) && use_async_copy_) annotations.Set(tir::attr::software_pipeline_async_stages, Array{0}); - auto for_node = GetRef(loop); + auto for_node = tvm::ffi::GetRef(loop); for_node.CopyOnWrite()->annotations = annotations; return for_node; } @@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() { return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/simplify.cc b/src/transform/simplify.cc index f1a64c306..d64c7016d 100644 --- a/src/transform/simplify.cc +++ b/src/transform/simplify.cc @@ -23,6 +23,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; using namespace arith; struct SimplifyConfigNode : public AttrsNodeReflAdapter { @@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { "branch", refl::DefaultValue(false)); } - static constexpr const char *_type_key = "tl.transform.SimplifyConfig"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig", + SimplifyConfigNode, BaseAttrsNode); RewriteSimplifier::Extension GetEnabledExtensions() const { RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; @@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) { class SimplifyConfig : public Attrs { public: - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, - SimplifyConfigNode); + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, + SimplifyConfigNode); }; -TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } -TVM_REGISTER_NODE_TYPE(SimplifyConfigNode); TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { @@ -391,7 +391,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (can_inline && !used_in_buffer_def) { return body; } else if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { auto n = this->CopyOnWrite(op); n->value = std::move(value); @@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) { return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.Simplify", Simplify); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/split_host_device.cc b/src/transform/split_host_device.cc index 6e9ae914a..a9f52f41d 100644 --- a/src/transform/split_host_device.cc +++ b/src/transform/split_host_device.cc @@ -37,7 +37,7 @@ namespace tvm { namespace tl { - +using namespace ffi; namespace tir = tvm::tir; class HostDeviceSplitter : public tir::StmtMutator { @@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() { {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/storage_access.cc b/src/transform/storage_access.cc index 806414c00..67900c3a1 100644 --- a/src/transform/storage_access.cc +++ b/src/transform/storage_access.cc @@ -39,10 +39,11 @@ using namespace tir; void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) { Var buf = op->buffer->data; - buffer_data_to_buffer_.Set(GetRef(buf.get()), op->buffer); + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); StorageScope scope = GetScope(buf); if (Enabled(buf.get(), scope)) { - ICHECK(allow_append_) << GetRef(op) << " " << scope.to_string(); + ICHECK(allow_append_) << tvm::ffi::GetRef(op) << " " + << scope.to_string(); AccessEntry e; e.threads = env_threads(); e.thread_range = this->ComputeThreadRange(e.threads); @@ -66,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) { curr_stmt_.stmt = op; Var buf = op->buffer->data; - buffer_data_to_buffer_.Set(GetRef(buf.get()), op->buffer); + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buf.get()), op->buffer); StorageScope scope = GetScope(buf); if (Enabled(buf.get(), scope)) { AccessEntry e; @@ -326,8 +327,8 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { Buffer buffer = load->buffer; DataType dtype = buffer->dtype; const VarNode *buffer_var = buffer->data.as(); - buffer_data_to_buffer_.Set(GetRef(buffer_var), buffer); - StorageScope scope = GetScope(GetRef(buffer_var)); + buffer_data_to_buffer_.Set(tvm::ffi::GetRef(buffer_var), buffer); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); Array buffer_ranges; // from indices to buffer indices ICHECK(buffer->shape.size() == load->indices.size()); @@ -365,17 +366,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode *flag = op->args[4].as(); - StorageScope scope = GetScope(GetRef(buffer_var)); + StorageScope scope = GetScope(tvm::ffi::GetRef(buffer_var)); // The buffer scope. if (Enabled(buffer_var, scope)) { ICHECK(allow_append_); Array buffer_ranges; - if (buffer_data_to_buffer_.find(GetRef(buffer_var)) == + if (buffer_data_to_buffer_.find(tvm::ffi::GetRef(buffer_var)) == buffer_data_to_buffer_.end()) { // cannot find buffer map, use the default buffer buffer_ranges = {Range::FromMinExtent(offset, extent)}; } else { - Buffer buffer = buffer_data_to_buffer_.at(GetRef(buffer_var)); + Buffer buffer = + buffer_data_to_buffer_.at(tvm::ffi::GetRef(buffer_var)); auto buffer_shape = buffer->shape; // convert 1d offset to multi-dimensional index auto linear_to_indices = [this](PrimExpr offset, @@ -406,7 +408,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) { e.threads = env_threads(); e.thread_range = this->ComputeThreadRange(e.threads); e.dtype = dtype; - e.buffer = GetRef(buffer_var); + e.buffer = tvm::ffi::GetRef(buffer_var); e.buffer_ranges = buffer_ranges; e.is_pointer_access = true; e.touched = { diff --git a/src/transform/storage_access.h b/src/transform/storage_access.h index c0d0ed470..54114ace2 100644 --- a/src/transform/storage_access.h +++ b/src/transform/storage_access.h @@ -39,6 +39,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc index da8f0943e..3324677c8 100644 --- a/src/transform/storage_rewrite.cc +++ b/src/transform/storage_rewrite.cc @@ -544,7 +544,7 @@ class StoragePlanRewriter : public StmtExprMutator { } return it->second->alloc_var; } else { - return GetRef(op); + return tvm::ffi::GetRef(op); } } PrimExpr VisitExpr_(const CallNode *op) final { @@ -978,8 +978,8 @@ class StoragePlanRewriter : public StmtExprMutator { ICHECK(alloc_info.count(var)); const AllocEntry &entry = alloc_info.at(var); const AllocateNode *alloc = entry.alloc; - auto storage_scope = - StorageScope::Create(GetPtrStorageScope(GetRef(var))); + auto storage_scope = StorageScope::Create( + GetPtrStorageScope(tvm::ffi::GetRef(var))); StorageEntry *dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -1732,7 +1732,7 @@ class VectorTypeRewriter : public StmtExprMutator { Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var; if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } return LetStmt(var, value, body); } @@ -1985,10 +1985,10 @@ Pass StorageRewrite() { return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite); -}); +} Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) { @@ -1997,11 +1997,11 @@ Pass PointerValueTypeRewrite() { return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite", PointerValueTypeRewrite); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index be120b62f..0627678e1 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -850,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync); -}); +} } // namespace transform } // namespace tl diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index 8891b0084..ae1545796 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -43,6 +43,7 @@ namespace tvm { namespace tl { using namespace tir; +using namespace ffi; /*! * \brief Perform data type legalization on the given BufferLoadNode pointer. @@ -242,7 +243,7 @@ class TLVectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); @@ -296,7 +297,7 @@ class TLVectorizer : public StmtMutator, PrimExpr VisitExpr_(const NotNode *op) final { PrimExpr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return !(a); } @@ -337,10 +338,10 @@ class TLVectorizer : public StmtMutator, PrimExpr value = this->VisitExpr(op->value); if (value.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } if (value.same_as(op->value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Broadcast(op->value, op->lanes); } @@ -352,7 +353,7 @@ class TLVectorizer : public StmtMutator, PrimExpr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor(); @@ -370,7 +371,7 @@ class TLVectorizer : public StmtMutator, PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { if (value.dtype().is_scalable_vector()) { return Cast(op->dtype.with_scalable_vscale_factor( @@ -383,20 +384,20 @@ class TLVectorizer : public StmtMutator, } PrimExpr VisitExpr_(const FloatImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const IntImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr VisitExpr_(const StringImmNode *op) final { - return GetRef(op); + return tvm::ffi::GetRef(op); } // Variable PrimExpr VisitExpr_(const VarNode *op) final { - Var var = GetRef(op); + Var var = tvm::ffi::GetRef(op); if (var.same_as(var_)) { return ramp_; @@ -413,13 +414,13 @@ class TLVectorizer : public StmtMutator, PrimExpr cond = this->VisitExpr(op->args[0]); if (cond.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor(); @@ -441,7 +442,7 @@ class TLVectorizer : public StmtMutator, ICHECK(op->op.same_as(builtin::reinterpret())); PrimExpr value = this->VisitExpr(op->args[0]); if (value.same_as(op->args[0])) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int lanes = value.dtype().get_lanes_or_vscale_factor(); if (value.dtype().is_scalable_vector()) { @@ -486,12 +487,12 @@ class TLVectorizer : public StmtMutator, auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { need_scalarize_ = true; - return GetRef(op); + return tvm::ffi::GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Call(op->dtype, op->op, new_args); } @@ -500,7 +501,7 @@ class TLVectorizer : public StmtMutator, Array new_args = MutateArray(op->args, &lane); // normal code path. if (op->args.same_as(new_args)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Call(op->dtype.with_lanes(lane), op->op, new_args); } @@ -508,7 +509,7 @@ class TLVectorizer : public StmtMutator, } // BufferLoad PrimExpr VisitExpr_(const BufferLoadNode *op) final { - auto load = GetRef(op); + auto load = tvm::ffi::GetRef(op); auto fmutate = [this](const PrimExpr &index) { return this->VisitExpr(index); @@ -547,7 +548,7 @@ class TLVectorizer : public StmtMutator, let_binding_[op->var] = op->var; PrimExpr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Let(op->var, value, body); } @@ -555,7 +556,7 @@ class TLVectorizer : public StmtMutator, } // BufferStore Stmt VisitStmt_(const BufferStoreNode *op) final { - auto store = GetRef(op); + auto store = tvm::ffi::GetRef(op); auto fmutate = [this](const PrimExpr &index) { return this->VisitExpr(index); @@ -618,11 +619,11 @@ class TLVectorizer : public StmtMutator, ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); PrimExpr extent = this->VisitExpr(op->extent); if (extent.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, op->annotations); @@ -633,7 +634,7 @@ class TLVectorizer : public StmtMutator, ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); PrimExpr condition = this->VisitExpr(op->condition); if (condition.dtype().is_scalable_or_fixed_length_vector()) { - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } Stmt then_case = this->VisitStmt(op->then_case); Optional else_case = std::nullopt; @@ -642,7 +643,7 @@ class TLVectorizer : public StmtMutator, } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return IfThenElse(condition, then_case, else_case); } @@ -667,7 +668,7 @@ class TLVectorizer : public StmtMutator, let_binding_[op->var] = op->var; Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return LetStmt(op->var, value, body); } @@ -681,7 +682,7 @@ class TLVectorizer : public StmtMutator, if (condition.dtype().is_scalable_or_fixed_length_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc of " << op->buffer_var->name_hint; - return Scalarize(GetRef(op)); + return Scalarize(tvm::ffi::GetRef(op)); } return StmtMutator::VisitStmt_(op); @@ -746,7 +747,7 @@ class TLVectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -762,7 +763,7 @@ class TLVectorizer : public StmtMutator, PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor(); @@ -842,10 +843,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) { return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index b86ebaf96..fd02c0240 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -159,7 +159,7 @@ class WarpSpecializedRoleMarker : public StmtVisitor { // Check reads from global Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", - /*body*/ GetRef(op)); + /*body*/ tvm::ffi::GetRef(op)); auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); auto reads = access[0]; Role role = Role::kProducer; @@ -511,7 +511,7 @@ class GroupOpRewriter : public StmtExprMutator { annotations.Set(String("stmt_group"), Integer(1)); auto original_node = (op->body).as(); if (!original_node) { - return GetRef(op); + return tvm::ffi::GetRef(op); } Array new_body; int cur_id = 0; @@ -646,7 +646,7 @@ class WSCodeEmitter : public StmtMutator { if (role == Role::kBoth) { return StmtMutator::VisitStmt_(op); } else if ((role == Role::kProducer) == is_emitting_producer_) { - return GetRef(op); + return tvm::ffi::GetRef(op); } else { return Evaluate(0); } @@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() { return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, disable_shuffle_elect); } else { - ObjectRef node = String("default"); + auto node = ffi::String("default"); f.CopyOnWrite()->body = AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); return f; @@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() { return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized); -}); +} } // namespace tl } // namespace tvm diff --git a/src/transform/wgmma_sync_rewriter.cc b/src/transform/wgmma_sync_rewriter.cc index 0b5a5eb39..538b49110 100644 --- a/src/transform/wgmma_sync_rewriter.cc +++ b/src/transform/wgmma_sync_rewriter.cc @@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() { return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {}); } -TVM_FFI_STATIC_INIT_BLOCK({ +TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync); -}); +} } // namespace tl } // namespace tvm diff --git a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py b/testing/python/jit/test_tilelang_jit_gemm_ctypes.py index 650bb2f97..fd5243f00 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py +++ b/testing/python/jit/test_tilelang_jit_gemm_ctypes.py @@ -85,7 +85,7 @@ def run_gemm( stramp = "&*(XS)" - @tvm.register_func("tilelang_callback_cuda_postproc", override=True) + @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) def tilelang_callback_cuda_postproc(code, _): code = f"// {stramp}\n" + code return code @@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape(): if __name__ == "__main__": - tilelang.testing.main() + # tilelang.testing.main() + test_gemm_f16f16f16_nn() diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index efffc0fa8..12524f129 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -85,7 +85,7 @@ def run_gemm( stramp = "&*(XS)" - @tvm.register_func("tilelang_callback_cuda_postproc", override=True) + @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) def tilelang_callback_cuda_postproc(code, _): code = f"// {stramp}\n" + code return code diff --git a/tilelang/_ffi_api.py b/tilelang/_ffi_api.py index d4fb0be49..6e6421bf7 100644 --- a/tilelang/_ffi_api.py +++ b/tilelang/_ffi_api.py @@ -1,6 +1,6 @@ """FFI APIs for tilelang""" -import tvm.ffi +import tvm_ffi # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); -tvm.ffi._init_api("tl", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("tl", __name__) diff --git a/tilelang/contrib/dlpack.py b/tilelang/contrib/dlpack.py index 58e82f8b1..e61d80cee 100644 --- a/tilelang/contrib/dlpack.py +++ b/tilelang/contrib/dlpack.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Wrapping functions to bridge frameworks with DLPack support to TVM""" -from tvm.runtime import ndarray +from tvm import runtime def convert_func(tvm_func, tensor_type, to_dlpack_func): @@ -49,9 +49,9 @@ def adapt_tensor(arg): torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz }: - return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( + return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view( arg.shape, dtype=float8_dtype_map[arg.dtype]) - return ndarray.from_dlpack(to_dlpack_func(arg)) + return runtime.from_dlpack(to_dlpack_func(arg)) return arg def _wrapper(*args): diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py index 92fbcc8e3..4e3c9a5c3 100644 --- a/tilelang/contrib/hipcc.py +++ b/tilelang/contrib/hipcc.py @@ -9,7 +9,7 @@ import subprocess -import tvm.ffi +import tvm_ffi from tvm.contrib import utils from tvm.base import py_str @@ -96,7 +96,7 @@ def compile_hip(code, return data -@tvm.ffi.register_func("tilelang_callback_hip_compile", override=True) +@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True) def tilelang_callback_hip_compile(code, target): """use hipcc to generate fatbin code for better optimization""" hsaco = compile_hip(code, target_format="hsaco") diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 8e813d92b..7d2e9d56b 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -8,8 +8,8 @@ import subprocess import warnings from tilelang.env import CUDA_HOME - -import tvm.ffi +import tvm_ffi +from tilelang import tvm as tvm from tvm.target import Target from tvm.base import py_str @@ -182,14 +182,14 @@ def get_cuda_version(cuda_path=None): raise RuntimeError("Cannot read cuda version file") -@tvm.ffi.register_func("tilelang_callback_cuda_compile", override=True) +@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) def tilelang_callback_cuda_compile(code, target): # pylint: disable=unused-argument """use nvcc to generate fatbin code for better optimization""" ptx = compile_cuda(code, target_format="fatbin") return ptx -@tvm.ffi.register_func("tilelang_callback_libdevice_path", override=True) +@tvm_ffi.register_global_func("tilelang_callback_libdevice_path", override=True) def find_libdevice_path(arch): """Utility function to find libdevice @@ -254,7 +254,7 @@ def callback_libdevice_path(arch): return "" -@tvm.ffi.register_func("tvm.contrib.nvcc.get_compute_version", override=True) +@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version", override=True) def get_target_compute_version(target=None): """Utility function to get compute capability of compilation target. @@ -400,7 +400,7 @@ def have_cudagraph(): return False -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_bf16", override=True) +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16", override=True) def have_bf16(compute_version): """Either bf16 support is provided in the compute capability or not @@ -413,7 +413,7 @@ def have_bf16(compute_version): return major >= 8 -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_fp8", override=True) +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8", override=True) def have_fp8(compute_version): """Whether fp8 support is provided in the specified compute capability or not @@ -430,7 +430,7 @@ def have_fp8(compute_version): return any(conditions) -@tvm.ffi.register_func("tvm.contrib.nvcc.supports_tma", override=True) +@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_tma", override=True) def have_tma(target): """Whether TMA support is provided in the specified compute capability or not diff --git a/tilelang/contrib/rocm.py b/tilelang/contrib/rocm.py index 8bb9e1d85..4a57c3c64 100644 --- a/tilelang/contrib/rocm.py +++ b/tilelang/contrib/rocm.py @@ -21,7 +21,7 @@ import os from os.path import join, exists -import tvm.ffi +import tvm_ffi from tvm.base import py_str import tvm.runtime import tvm.target @@ -100,7 +100,7 @@ def rocm_link(in_file, out_file, lld=None): raise RuntimeError(msg) -@tvm.ffi.register_func("tvm_callback_rocm_link", override=True) +@tvm_ffi.register_global_func("tvm_callback_rocm_link", override=True) def callback_rocm_link(obj_bin): """Links object file generated from LLVM to HSA Code Object @@ -124,7 +124,7 @@ def callback_rocm_link(obj_bin): return cobj_bin -@tvm.ffi.register_func("tvm_callback_rocm_bitcode_path", override=True) +@tvm_ffi.register_global_func("tvm_callback_rocm_bitcode_path", override=True) def callback_rocm_bitcode_path(rocdl_dir=None): """Utility function to find ROCm device library bitcodes @@ -226,7 +226,7 @@ def have_matrixcore(compute_version=None): return False -@tvm.ffi.register_func("tvm_callback_rocm_get_arch", override=True) +@tvm_ffi.register_global_func("tvm_callback_rocm_get_arch", override=True) def get_rocm_arch(rocm_path="/opt/rocm"): """Utility function to get the AMD GPU architecture diff --git a/tilelang/engine/callback.py b/tilelang/engine/callback.py index ee1c80693..05fafe9db 100644 --- a/tilelang/engine/callback.py +++ b/tilelang/engine/callback.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Callable -from tvm import register_func +import tvm_ffi from tvm.target import Target @@ -12,7 +12,7 @@ def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = and returns the processed code (str). override: Whether to override existing registered function. Defaults to True. """ - register_func("tilelang_callback_cuda_postproc", f=func, override=override) + tvm_ffi.register_global_func("tilelang_callback_cuda_postproc", f=func, override=override) def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True): @@ -23,7 +23,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T and returns the processed code (str). override: Whether to override existing registered function. Defaults to True. """ - register_func("tilelang_callback_hip_postproc", f=func, override=override) + tvm_ffi.register_global_func("tilelang_callback_hip_postproc", f=func, override=override) def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 8738f58a1..d0c27b4c2 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -7,6 +7,7 @@ import tilelang.transform from tilelang import tvm as tvm from tvm import tir +import tvm_ffi from tvm.ir import CallingConv from tvm.target import Target from tilelang.contrib import hipcc, nvcc @@ -52,7 +53,7 @@ def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]: return lambda func: not get_device_call(is_device_c)(func) -@tvm.register_func("tilelang_callback_cuda_compile", override=True) +@tvm_ffi.register_global_func("tilelang_callback_cuda_compile", override=True) def tilelang_callback_cuda_compile(code, target): project_root = osp.join(osp.dirname(__file__), "../..") if "TL_TEMPLATE_PATH" in os.environ: @@ -89,7 +90,7 @@ def tilelang_callback_cuda_compile(code, target): return ptx -@tvm.register_func("tilelang_callback_hip_compile", override=True) +@tvm_ffi.register_global_func("tilelang_callback_hip_compile", override=True) def tilelang_callback_hip_compile(code, target): project_root = osp.join(osp.dirname(__file__), "../..") tl_template_path = osp.abspath(osp.join(project_root, "src")) @@ -181,7 +182,7 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> elif target.kind.name == "llvm": device_mod = tvm.ffi.get_global_func("target.build.llvm")(device_mod, target) elif target.kind.name == "webgpu": - device_mod = tvm.ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target) + device_mod = tvm.ffi.get_global_func("target.build.webgpu")(device_mod, target) elif target.kind.name == "metal": device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target) else: @@ -240,6 +241,6 @@ def lower( host_mod = host_codegen(host_mod, target_host) host_mod.import_module(codegen_mod) return CompiledArtifact( - host_mod, device_mod, params, codegen_mod.get_source(), rt_mod=host_mod) + host_mod, device_mod, params, codegen_mod.inspect_source(), rt_mod=host_mod) - return CompiledArtifact(host_mod, device_mod, params, codegen_mod.get_source()) + return CompiledArtifact(host_mod, device_mod, params, codegen_mod.inspect_source()) diff --git a/tilelang/ir.py b/tilelang/ir.py index d48aeeed8..cccf97e0a 100644 --- a/tilelang/ir.py +++ b/tilelang/ir.py @@ -1,32 +1,32 @@ from tilelang import tvm as tvm from tvm.ir.base import Node from tvm.runtime import Scriptable -import tvm.ffi +import tvm_ffi from tvm.target import Target from tilelang import _ffi_api -@tvm.ffi.register_object("tl.Fill") +@tvm_ffi.register_object("tl.Fill") class Fill(Node, Scriptable): ... -@tvm.ffi.register_object("tl.AtomicAdd") +@tvm_ffi.register_object("tl.AtomicAdd") class AtomicAdd(Node, Scriptable): ... -@tvm.ffi.register_object("tl.Copy") +@tvm_ffi.register_object("tl.Copy") class Copy(Node, Scriptable): ... -@tvm.ffi.register_object("tl.Conv2DIm2Col") +@tvm_ffi.register_object("tl.Conv2DIm2Col") class Conv2DIm2ColOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.GemmWarpPolicy") +@tvm_ffi.register_object("tl.GemmWarpPolicy") class GemmWarpPolicy(Node, Scriptable): policy_type: int m_warp: int @@ -39,41 +39,41 @@ def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target return self.m_warp, self.n_warp -@tvm.ffi.register_object("tl.Gemm") +@tvm_ffi.register_object("tl.Gemm") class Gemm(Node, Scriptable): ... -@tvm.ffi.register_object("tl.GemmSP") +@tvm_ffi.register_object("tl.GemmSP") class GemmSP(Node, Scriptable): ... -@tvm.ffi.register_object("tl.FinalizeReducerOp") +@tvm_ffi.register_object("tl.FinalizeReducerOp") class FinalizeReducerOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.ParallelOp") +@tvm_ffi.register_object("tl.ParallelOp") class ParallelOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.ReduceOp") +@tvm_ffi.register_object("tl.ReduceOp") class ReduceOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.CumSumOp") +@tvm_ffi.register_object("tl.CumSumOp") class CumSumOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.RegionOp") +@tvm_ffi.register_object("tl.RegionOp") class RegionOp(Node, Scriptable): ... -@tvm.ffi.register_object("tl.ReduceType") +@tvm_ffi.register_object("tl.ReduceType") class ReduceType(Node, Scriptable): ... diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index b9c2b10ec..06fc7a987 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -3,13 +3,14 @@ from __future__ import annotations import tvm +import tvm_ffi from tvm.ir import Range from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tilelang import _ffi_api from tilelang.layout import Layout -@tvm.ffi.register_object("tl.Fragment") +@tvm_ffi.register_object("tl.Fragment") class Fragment(Layout): """ A Fragment layout object that encapsulates iteration variables (forward_vars), diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index dd0f11709..14db12223 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -2,14 +2,14 @@ # pylint: disable=invalid-name, unsupported-binary-operation from __future__ import annotations -import tvm +import tvm_ffi from tvm.ir import Node, Range from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tilelang import _ffi_api # Register the Layout class as a TVM object under the name "tl.Layout" -@tvm.ffi.register_object("tl.Layout") +@tvm_ffi.register_object("tl.Layout") class Layout(Node): def __init__(self, shape, forward_fn): diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index d0ea704cc..178fc96dc 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -4,7 +4,7 @@ from tvm.target import Target from tvm.ir.base import Node from tvm.runtime import Scriptable -import tvm.ffi +import tvm_ffi from tilelang.ir import GemmWarpPolicy from .gemm_mma import GemmMMA from .gemm_wgmma import GemmWGMMA @@ -12,13 +12,13 @@ from tilelang import _ffi_api -@tvm.ffi.register_func("tl.gemm_py.infer_layout") +@tvm_ffi.register_global_func("tl.gemm_py.infer_layout") def gemm_py_infer_layout(gemm_py, target, thread_bounds): thread_nums = thread_bounds.extent return gemm_py.infer_layout(target, thread_nums) -@tvm.ffi.register_func("tl.gemm_py.lower") +@tvm_ffi.register_global_func("tl.gemm_py.lower") def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): thread_nums = thread_bounds.extent stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) @@ -46,7 +46,7 @@ def is_mfma(self) -> bool: return self == GemmInst.MFMA -@tvm.ffi.register_object("tl.GemmPy") +@tvm_ffi.register_object("tl.GemmPy") class GemmPy(Node, Scriptable): A: tir.Buffer B: tir.Buffer diff --git a/tilelang/transform/_ffi_api.py b/tilelang/transform/_ffi_api.py index c89dddda1..3692a32d6 100644 --- a/tilelang/transform/_ffi_api.py +++ b/tilelang/transform/_ffi_api.py @@ -1,6 +1,6 @@ """FFI APIs for tilelang""" -import tvm.ffi +import tvm_ffi # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); -tvm.ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("tl.transform", __name__) diff --git a/tilelang/utils/tensor.py b/tilelang/utils/tensor.py index 9d0c3c3a4..51f63db4a 100644 --- a/tilelang/utils/tensor.py +++ b/tilelang/utils/tensor.py @@ -2,7 +2,7 @@ """The profiler and convert to torch utils""" from enum import Enum import torch -from tvm.runtime import ndarray +from tvm import runtime from tvm import tir from torch.utils.dlpack import to_dlpack import numpy as np @@ -49,9 +49,9 @@ def adapt_torch2tvm(arg): if arg.dtype in { torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz }: - return ndarray.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view( + return runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view( shape=arg.shape, dtype=float8_dtype_map[arg.dtype]) - return ndarray.from_dlpack(to_dlpack(arg)) + return runtime.from_dlpack(to_dlpack(arg)) return arg