A shared middle-layer for the Triton Compiler.
Currently the middle layer is not complete, but it has enough functionality to demonstrate how it can work. The general idea is that Triton IR is lowered into MLIR core dialects so that it can be shared across Triton targets and so that back-ends can be shared with other languages.
The basic intended architecture looks like this:
[Triton IR] -> [Middle Layer] -> [HW specific IR]
The middle-layer uses MLIR's Linalg and Tensor Dialects for operations on Triton block values. Operations on Triton pointers use the Memref Dialect.
This talk at the 2023 Triton Developer Conference gives some background on the project and its goals.
This repo now includes triton as a submodule and builds as an out-of-tree backend. To build this repo, clone triton-shared into a folder called triton_shared (notice the underscore). Triton will use this folder name to create a module under triton.runtime for the reference CPU backend. You need to set the TRITON_PLUGIN_DIRS environment variable to the location of your triton_shared directory for triton to find it.
export TRITON_PLUGIN_DIRS=$(pwd)/triton_shared
git clone --recurse-submodules https://github.com/microsoft/triton-shared.git triton_shared
cd triton_shared/triton/python
To build with Clang:
python3 -m pip install --upgrade pip
python3 -m pip install cmake==3.24 ninja pytest-xdist
sudo apt-get update -y
sudo apt-get install -y ccache clang lld
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true python3 -m pip install --no-build-isolation -vvv '.[tests]'
To build with a virtualenv:
python3 -m venv .venv --prompt triton
source .venv/bin/activate
pip3 install ninja cmake wheel pytest
pip3 install -e python --no-build-isolation
The resulting triton-shared binaries will be placed under triton/python/build/{current_cmake_version}/third_party/triton_shared.
The middle layer can be used as a stand-alone component to convert the Triton dialect to the middle-layer dialects. This is intended for testing and validation purposes, but it could potentially be used before sending the IR to another MLIR compiler.
Stand-alone example:
triton-shared-opt --triton-to-linalg %file
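Here %file is a file containing the Triton-dialect IR (TTIR) of a kernel. How you obtain that IR depends on your Triton version; the following is a minimal sketch, assuming a recent Triton where `JITFunction.warmup` returns a compiled kernel whose `asm` dictionary has a `ttir` entry, and assuming an active backend to compile against (for example the reference CPU backend described below):

```python
# Hedged sketch: dump a kernel's Triton-dialect IR (TTIR) so it can be passed to
# `triton-shared-opt --triton-to-linalg`. Assumes recent Triton versions where
# warmup() returns a compiled kernel exposing asm["ttir"]; an active backend
# (a GPU backend or the reference CPU backend described below) is required.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets))


# Compile without launching; torch dtypes stand in for the pointer arguments.
compiled = copy_kernel.warmup(torch.float32, torch.float32, BLOCK=128, grid=(1,))
with open("copy_kernel.ttir", "w") as f:
    f.write(compiled.asm["ttir"])
# Then, from the shell: triton-shared-opt --triton-to-linalg copy_kernel.ttir
```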
The Triton middle layer is intended to be used as a component in a Triton back-end. This can be accomplished by adding the cmake targets it produces and its header files to that back-end. An example back-end will be published at a later date.
We also include an experimental reference CPU backend that leverages all existing mlir passes. After building, the CPU backend can be used by setting triton's active driver:
import triton
from triton.backends.triton_shared.driver import CPUDriver
triton.runtime.driver.set_active(CPUDriver())
For more examples, please refer to python/examples.
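As a quick end-to-end illustration, here is a minimal sketch that activates the CPU driver and runs the standard vector-add tutorial kernel on CPU torch tensors; the kernel body and the use of CPU tensors are assumptions modeled on the tutorial examples, not an API guarantee.

```python
# Minimal sketch of running a Triton kernel through the reference CPU backend.
# The vector-add kernel mirrors the standard Triton tutorial; using plain CPU
# torch tensors is an assumption based on python/examples.
import torch
import triton
import triton.language as tl
from triton.backends.triton_shared.driver import CPUDriver

triton.runtime.driver.set_active(CPUDriver())


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


n = 4096
x = torch.rand(n, device="cpu")
y = torch.rand(n, device="cpu")
out = torch.empty_like(x)
grid = (triton.cdiv(n, 1024),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
assert torch.allclose(out, x + y)
```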
Even though a valid triton program can perform loads and stores at arbitrary memory locations, the prototype only supports lowering programs that have structured memory access patterns.
As part of the conversion process, there are three important analyses:
- Pointer analysis: This analysis is responsible for extracting structured memory access patterns from a triton program during load and store; it walks the IR and visits relevant instructions to build strided memory accesses in the memref dialect. The analysis is still in its early stage and does not support all scenarios (see the sketch after this list).
- Use analysis: After "Pointer analysis", instructions that are part of memory address calculation will no longer be necessary in a triton program because their semantics have now been captured by memref operations representing strided memory accesses. To aid with removing these instructions safely, we perform Use analysis to mark which instructions are used only in address calculation (called MetaUse) or used in both address calculation and data manipulation (called MixedUse). Those that are MixedUse are cloned and have their users adjusted accordingly, with the goal of separating out the MetaUse ops so that they can be safely deleted.
- Mask analysis: This analysis is responsible for handling masked loads and stores.
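To make the notion of structured access concrete, here is a hedged sketch with two illustrative kernels (the names are hypothetical): the first derives its offsets from tl.arange, producing a strided access that the pointer analysis can recover; the second loads its offsets from memory, a data-dependent gather that falls outside the currently supported patterns.

```python
# Illustrative sketch only; kernel names are hypothetical.
import triton
import triton.language as tl


@triton.jit
def structured_copy(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)      # compile-time affine offsets
    x = tl.load(src_ptr + offsets)     # strided, structured access
    tl.store(dst_ptr + offsets, x)


@triton.jit
def gather_copy(src_ptr, idx_ptr, dst_ptr, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    idx = tl.load(idx_ptr + offsets)   # offsets come from memory (integer indices)
    x = tl.load(src_ptr + idx)         # data-dependent gather: unstructured
    tl.store(dst_ptr + offsets, x)
```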
We introduce the TritonToLinalg pass that converts the triton dialect to the linalg dialect on tensors. This means the resulting IR is fully compatible with linalg tiling and fusion transformation passes. As mentioned in the Pointer analysis description, we do however have to deal with memref instructions at the load and store boundaries and convert them to tensors using bufferization.to_tensor. Here's a simple example of what the IR looks like:
tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%1 = tt.splat %afloat : (!tt.ptr<bf16>) -> tensor<128x!tt.ptr<bf16>>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<bf16>>, tensor<128xi32>
%afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16>
%3 = "tt.reduce"(%afm) ({
^bb0(%arg5: bf16, %arg6: bf16):
%21 = arith.addf %arg5, %arg6 : bf16
tt.reduce.return %21 : bf16
}) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16
tt.store %res, %3 : bf16
tt.return
}
after conversion:
func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) {
%cst = arith.constant 0.000000e+00 : f32
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] :
memref<*xbf16> to memref<128xbf16, strided<[1]>>
%alloc = memref.alloc() : memref<128xbf16>
memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16>
%0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16>
%1 = bufferization.alloc_tensor() : tensor<f32>
%inserted = tensor.insert %cst into %1[] : tensor<f32>
%reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor<f32>) dimensions = [0]
(%in: bf16, %init: f32) {
%3 = arith.extf %in : bf16 to f32
%4 = arith.addf %3, %init : f32
linalg.yield %4 : f32
}
%extracted = tensor.extract %reduced[] : tensor<f32>
%2 = arith.truncf %extracted : f32 to bf16
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] :
memref<*xbf16> to memref<1xbf16, strided<[1]>>
affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>>
return
}
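For reference, the Triton-dialect input above corresponds roughly to a kernel like the following sketch (written here to match the IR; exact types and reduction handling may differ between Triton versions):

```python
# Hypothetical Triton source matching the IR example above: load 128 bf16
# values, sum them (tt.reduce with an arith.addf body), and store the scalar.
import triton
import triton.language as tl


@triton.jit
def kernel(afloat, res):
    offsets = tl.arange(0, 128)
    x = tl.load(afloat + offsets)   # tt.splat + tt.addptr + tt.load
    total = tl.sum(x, axis=0)       # tt.reduce {axis = 0}
    tl.store(res, total)            # scalar tt.store
```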
Important details to note:
- tt.load (together with all of its related address calculation instructions such as tt.addptr and tt.splat) is lowered to a combination of memref.reinterpret_cast, memref.alloc, and memref.copy. After the initialization of the local buffer, we convert the memref back to a tensor using bufferization.to_tensor; this op is automatically removed during bufferization.
- tt.store lowers to a combination of memref.reinterpret_cast and either affine.store or memref.tensor_store:
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
bufferization.materialize_in_destination %extracted_slice in writable %subview
- element-wise arith and math operators are converted to their corresponding linalg.generic versions.
- tt.dot becomes linalg.matmul.
- tt.reduce becomes linalg.reduce; known limitation: only addf and maxf reductions in the reduction body are supported for now (see the sketch after this list).
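As a rough guide to these lowerings, the hedged sketch below uses element-wise arith/math ops (which map to linalg.generic) together with the two currently supported reduction kinds, sum and max (which map to linalg.reduce with addf and maxf bodies); the kernel names are illustrative.

```python
# Illustrative kernels exercising the lowerings listed above (names are hypothetical).
import triton
import triton.language as tl


@triton.jit
def scaled_exp(x_ptr, out_ptr, scale, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offsets)
    y = tl.exp(x * scale)                   # element-wise arith/math -> linalg.generic
    tl.store(out_ptr + offsets, y)


@triton.jit
def block_sum_and_max(x_ptr, sum_ptr, max_ptr, BLOCK: tl.constexpr):
    offsets = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offsets)
    tl.store(sum_ptr, tl.sum(x, axis=0))    # tt.reduce (addf) -> linalg.reduce
    tl.store(max_ptr, tl.max(x, axis=0))    # tt.reduce (maxf) -> linalg.reduce
```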
The prototype was tested on the following triton kernel examples:
- vector addition
- fused softmax
- matrix multiplication
- layer normalization
- fused attention
The Python tests are set up to run with pytest, and you will need to set the following environment variables to run them:
export LLVM_BINARY_DIR=<path-to-your-llvm-binaries>
export TRITON_SHARED_OPT_PATH=$TRITON_PLUGIN_DIRS/triton/python/build/<your-cmake-directory>/third_party/triton_shared_opt/triton_shared-opt
pytest <path-to-triton-shared>/python/examples
In addition to testing on the tutorial kernels, there are many lit tests covering various scenarios.
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.