diff --git a/.github/workflows/_build.yml b/.github/workflows/_build.yml new file mode 100644 index 0000000000..3bbd5f0a4f --- /dev/null +++ b/.github/workflows/_build.yml @@ -0,0 +1,225 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.27 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + sub-packages: '["nvcc"]' + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. 
we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + # detect if we're on ARM + if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then + PLAT=linux_aarch64 + else + PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64 + fi + echo "PLAT=$PLAT" >> $GITHUB_ENV + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904 + pip install jinja2 + TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl + TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl + pip install --no-cache-dir --pre "${TRITON_URL}" + pip install --no-cache-dir --pre "${TORCH_URL}" + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . 
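As an aside, the version-selection one-liner in the "Install PyTorch" step above is easier to read as a small standalone sketch. The two dictionaries are copied verbatim from the workflow; the helper name `pick_torch_cuda_version` is just for illustration and is not part of the workflow:

```python
# Pick which cuXXX PyTorch wheel index to install from, given the torch minor
# version (e.g. "2.8") and the CUDA toolkit on the runner (e.g. 129 for 12.9).
MIN_CU = {"2.5": 118, "2.6": 118, "2.7": 118, "2.8": 126, "2.9": 126}
MAX_CU = {"2.5": 124, "2.6": 126, "2.7": 128, "2.8": 129, "2.9": 130}

def pick_torch_cuda_version(torch_minor: str, system_cuda: int) -> int:
    # CUDA 11.x runners get the oldest supported wheel index, CUDA 12+/13 the newest.
    return MIN_CU[torch_minor] if system_cuda < 120 else MAX_CU[torch_minor]

assert pick_torch_cuda_version("2.8", 129) == 129   # CUDA 12.9 runner -> cu129 wheels
assert pick_torch_cuda_version("2.6", 118) == 118   # CUDA 11.8 runner -> cu118 wheels
```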
+ else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export FLASH_ATTENTION_FORCE_BUILD="TRUE" + export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . 
--atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000000..25ea5e86b7 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,47 @@ +name: Build wheels + +on: + workflow_dispatch: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4746c71493..e88090f336 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,16 +13,16 @@ on: - v* jobs: - setup_release: name: Create Release runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} shell: bash - - name: Create Release id: create_release uses: actions/create-release@v1 @@ -35,184 +35,59 @@ jobs: build_wheels: name: Build Wheel needs: setup_release - runs-on: ${{ matrix.os }} - strategy: fail-fast: false matrix: - # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the - # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. - os: [ubuntu-20.04] - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] - cuda-version: ['11.8.0', '12.3.2'] - # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. 
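For context on the cxx11_abi matrix axis discussed in the surrounding comments: which wheel variant you need depends on how your PyTorch build was compiled, and you can query that directly. A minimal check, independent of this workflow:

```python
import torch

# True  -> PyTorch was built with -D_GLIBCXX_USE_CXX11_ABI=1 (e.g. the NGC/nvcr
#          containers mentioned in the comments); use the cxx11abiTRUE wheel.
# False -> the stock PyPI wheels; use the cxx11abiFALSE variant.
print(torch.compiled_with_cxx11_abi())
```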
- # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. - # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) - # when building without C++11 ABI and using it on nvcr images. - cxx11_abi: ['FALSE', 'TRUE'] - exclude: - # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '2.1.2' - python-version: '3.12' - # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.1.2' - python-version: '3.13' - - torch-version: '2.2.2' - python-version: '3.13' - - torch-version: '2.3.1' - python-version: '3.13' - - torch-version: '2.4.0' - python-version: '3.13' - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Set CUDA and PyTorch versions - run: | - echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - - - name: Free up disk space - if: ${{ runner.os == 'Linux' }} - # https://github.com/easimon/maximize-build-space/blob/master/action.yml - # https://github.com/easimon/maximize-build-space/tree/test-report - run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - - - name: Set up swap space - if: runner.os == 'Linux' - uses: pierotofy/set-swap-space@v1.0 - with: - swap-size-gb: 10 - - - name: Install CUDA ${{ matrix.cuda-version }} - if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.19 - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda-version }} - linux-local-args: '["--toolkit"]' - # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 - # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} - method: 'network' - sub-packages: '["nvcc"]' - - - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} - run: | - pip install --upgrade pip - # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==68.0.0 - # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error - # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable - pip install typing-extensions==4.12.2 - # We want to figure out the CUDA version to download pytorch - # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04, ubuntu-22.04-arm] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
+ # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["FALSE", "TRUE"] + include: + - torch-version: "2.9.0.dev20250904" + cuda-version: "13.0" + exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ - print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ - ) - if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - # Can't use --no-deps because we need cudnn etc. - # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 - pip install jinja2 - pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} - fi - nvcc --version - python --version - python -c "import torch; print('PyTorch:', torch.__version__)" - python -c "import torch; print('CUDA:', torch.version.cuda)" - python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" - shell: - bash - - - name: Build wheel - run: | - # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 - # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 - # However this still fails so I'm using a newer version of setuptools - pip install setuptools==68.0.0 - pip install ninja packaging wheel - export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH - export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Limit MAX_JOBS otherwise the github runner goes OOM - # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist - tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") - ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - - - name: Log Built Wheels - run: | - ls dist - - - name: Get the tag version - id: extract_branch - run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} - - - name: Get Release with tag - id: get_current_release - uses: joutvhu/get-release@v1 - with: - tag_name: ${{ steps.extract_branch.outputs.branch }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Upload Release Asset - id: 
upload_release_asset - uses: actions/upload-release-asset@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./dist/${{env.wheel_name}} - asset_name: ${{env.wheel_name}} - asset_content_type: application/* + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true publish_package: name: Publish package needs: [build_wheels] - runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 with: - python-version: '3.10' - + python-version: "3.10" - name: Install dependencies run: | - pip install ninja packaging setuptools wheel twine + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu - - name: Build core package env: FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist - - name: Deploy env: TWINE_USERNAME: "__token__" diff --git a/.gitmodules b/.gitmodules index 6216182e72..a6446cc597 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,4 @@ [submodule "csrc/composable_kernel"] path = csrc/composable_kernel url = https://github.com/ROCm/composable_kernel.git + branch = amd-master diff --git a/MANIFEST.in b/MANIFEST.in index 021b4d0f7d..d3c4b4eda1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ recursive-include csrc *.h recursive-include csrc *.cuh recursive-include csrc *.cpp recursive-include csrc *.hpp +recursive-include csrc *.py recursive-include flash_attn *.cu recursive-include flash_attn *.h diff --git a/README.md b/README.md index 033dba4100..dd7f1c1646 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Currently released: Requirements: H100 / H800 GPU, CUDA >= 12.3. -For now, we highly recommend CUDA 12.3 for best performance. +We highly recommend CUDA 12.8 for best performance. To install: ```sh @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 1.12 and above. +- PyTorch 2.2 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. @@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation ### NVIDIA CUDA Support **Requirements:** -- CUDA 11.7 and above. +- CUDA 12.0 and above. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) @@ -137,38 +137,74 @@ These features are supported in Fwd and Bwd 2) Variable sequence lengths 3) Arbitrary Q and KV sequence lengths 4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings +8) ALiBi -These features are supported in Fwd for now. 
We will add them to backward soon. -1) Multi and grouped query attention -2) ALiBi and matrix bias - -These features are in development +We are working on the following things 1) Paged Attention 2) Sliding Window -3) Rotary embeddings -4) Dropout -5) Performance Improvements +3) FP8 +4) Performance Improvements -#### Getting Started +##### Getting Started To get started with the triton backend for AMD, follow the steps below. -First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4). +First install the recommended Triton version ``` -git clone https://github.com/triton-lang/triton -cd triton -git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 -pip install --verbose -e python +pip install triton==3.2.0 ``` -Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. +Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. ``` -export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" cd flash-attention -python setup.py install -pytest tests/test_flash_attn.py +git checkout main_perf +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install +``` + +To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py ``` +You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE +``` + +###### Docker +You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. +``` +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install triton +RUN pip install triton==3.2.0 + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention +``` + +To build the docker file +``` +docker build -t fa_triton . 
+``` + +To run the docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` ## How to use FlashAttention diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py new file mode 100644 index 0000000000..b3902110ee --- /dev/null +++ b/benchmarks/benchmark_attn.py @@ -0,0 +1,411 @@ +from collections import namedtuple +from functools import partial +import math +import os +from typing import NamedTuple +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None +# cudnn = None + +Timing = NamedTuple('timing', [('mean', float)]) + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func +from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python +from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python +try: + from flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 +except ImportError: + flash_attn_func_v3 = None + flash_attn_varlen_func_v3 = None + +if torch.cuda.get_device_capability()[0] != 9: + flash_attn_func_v3 = None +# flash_attn_func_v3 = None + +flash_attn_func = None + +from triton.testing import do_bench + +def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): + # # Warmup + # for _ in range(5): + # func(*args, **kwargs) + # time.sleep(1) + # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1] + # s = torch.cuda.Stream() + # s.wait_stream(torch.cuda.current_stream()) + # with torch.cuda.stream(s): + # for _ in range(2): + # out = func(*args, **kwargs) + # torch.cuda.current_stream().wait_stream(s) + # graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(graph): + # out = func(*args, **kwargs) + # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) + # # return time_f[1].mean + # return time_f[1] + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)): + if causal: + avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 + else: + if window_size == (None, None): + avg_seqlen = seqlen_k + else: + row_idx = torch.arange(seqlen_q, device='cuda') + col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx) + col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1) + avg_seqlen = (col_right - col_left + 1).float().mean().item() + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return 
cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_spda_setup(q, k, v, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu = torch.empty((b, nheads, seqlen_q, headdim_v), dtype=q.dtype, device=q.device) + stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + + o, stats = graph.sdpa( + name="sdpa", + q=q, + k=k, + v=v, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + stats: stats_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_k, seqlen_k, _ = k.shape + headdim_v = v.shape[-1] + assert v.shape == (b, nheads_k, seqlen_k, headdim_v) + assert g.shape == (b, nheads, seqlen_q, headdim_v) + assert o.shape == (b, nheads, seqlen_q, headdim_v) + assert lse.shape == (b, nheads, seqlen_q, 1) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g + dq_gpu = torch.empty_like(q_gpu) + dk_gpu = torch.empty_like(k_gpu) + dv_gpu = torch.empty_like(v_gpu) + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = graph.tensor_like(q_gpu.detach()) + k = graph.tensor_like(k_gpu.detach()) + v = graph.tensor_like(v_gpu.detach()) + o = graph.tensor_like(o_gpu.detach()) + g = graph.tensor_like(g_gpu.detach()) + stats = graph.tensor_like(lse.detach()) + + dq, dk, dv = graph.sdpa_backward( + name="sdpa_backward", + q=q, + k=k, + v=v, + o=o, + dO=g, + stats=stats, + attn_scale=1.0 / math.sqrt(headdim), + # use_causal_mask_bottom_right=causal or window_size_left is not None, + use_causal_mask=causal or window_size_left is not None, + sliding_window_length=window_size_left if window_size_left is not None and not causal else None, + ) + + dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) + 
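# dk and dv are registered as graph outputs in the same way below, each pinned to the
# shape and stride of its preallocated buffer so cuDNN writes the gradients in place.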
dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) + dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + g: g_gpu, + stats: lse, + dq: dq_gpu, + dk: dk_gpu, + dv: dv_gpu, + } + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(*args, **kwargs): + graph.execute(variant_pack, workspace) + return dq_gpu, dk_gpu, dv_gpu + + return run + + +torch.manual_seed(0) +repeats = 10 +dropout_p = 0.0 +causal = False +dtype = torch.bfloat16 +# dtype = torch.float8_e4m3fn +dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype +device = 'cuda' +verbose = True +varlen = False +has_backward = False +page_size = None +# page_size = 128 +softcap = 0.0 +V_colmajor = False +deterministic = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 +# for headdim in [64, 128, 256]: +# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] +bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)] +# bs_seqlen_vals = [(32, 512), (16, 1024)] +# bs_seqlen_vals = [(2, 64 * 132)] +# bs_seqlen_vals = [(4, 8192)] +# bs_seqlen_vals = [(1, 16 * 1024)] +time_f = {} +time_b = {} + +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192]: +# for headdim in [64, 96, 128, 192, 256]: +# for headdim in [64, 96, 128]: +# for headdim in [64, 128, 256]: +# for headdim in [64, 96, 128, 192, 256]: +for headdim in [128]: + # nheads = dim // headdim + nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8 + # nheads = 128 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + # nheads_kv = nheads + nheads_kv = nheads // 8 + # nheads_kv = 1 + # headdim_v = headdim + headdim_v = 128 if headdim == 192 else headdim + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False + # sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device) + sinks = None + + for batch_size, seqlen in bs_seqlen_vals: + num_splits = 0 + # window_size = (-1, -1) + window_size = (None, None) + window_size_fa = (-1, -1) + # window_size = (seqlen // 2 - 1, 0) + pack_gqa = None + # seqlen_q = 64 + seqlen_q = seqlen + leftpad_k = None + # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) + q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] + v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_(has_backward) + v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None + # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) + # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, 
dtype=torch.int32).to(dtype) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) + stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) + if varlen: + q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] + cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q + cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None + # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:256] + # seqlen_q = 256 + # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32) + # q_unpad = q_unpad[:384] + # seqlen_q = 384 + if page_size is not None: + assert seqlen % page_size == 0 + k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + + # for causal in [False, True]: + for causal in [True]: + print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if has_backward and headdim == headdim_v: + cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + if not varlen: + m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + else: + m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') + time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean + if has_backward: + time.sleep(1) + if not varlen: + _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + else: + _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav2') + time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) + + if cudnn is not None: + # if False: + if headdim <= 256 and dtype != torch.float8_e4m3fn: + 
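                # Benchmark the cuDNN SDPA graph only for the configurations it was
                # built for above: head dim <= 256 and non-FP8 inputs.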
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') + time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean + if has_backward: + time.sleep(1) + m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean + # pytorch_profiler(cudnn_spda, backward=False) + # pytorch_profiler(cudnn_spda_bwd, backward=False) + time.sleep(1) + if flash_attn_func_v3 is not None: + if not varlen: + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) + else: + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) + time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean + if flash_attn_func_python is not None: + if not varlen: + m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + else: + m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward: + time.sleep(1) + if not varlen: + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3') + else: + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + repeats=repeats, verbose=False, desc='Fav3') + time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean + # time.sleep(1) + # if not varlen: + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) + # else: + # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, 
deterministic=deterministic, backward=True) + # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward: + _, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python') + + if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None: + # if False: + print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') + if cudnn is not None: + print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') + if has_backward: + print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') + if flash_attn_func_v3 is not None: + print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') + + if flash_attn_func_python is not None: + print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS') + if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward: + print(f'FAv2 Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS') diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 6c4797c83e..c97581c658 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -17,12 +17,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func -try: - from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax -except ImportError: - scaled_upper_triang_masked_softmax = None - - def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: @@ -52,27 +46,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): return output.to(dtype=qkv.dtype) -def attention_megatron(qkv): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - Output: - output: (batch_size, seqlen, nheads, head_dim) - """ - batch_size, seqlen, _, nheads, d = qkv.shape - q, k, v = qkv.unbind(dim=2) - q = rearrange(q, 'b t h d -> (b h) t d') - k = rearrange(k, 'b s h d -> (b h) d s') - softmax_scale = 1.0 / math.sqrt(d) - # Preallocate attn_weights for `baddbmm` - scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) - scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), - '(b h) t s -> b h t s', h=nheads) - attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0) - output = torch.einsum('bhts,bshd->bthd', attention, v) - return output.to(dtype=qkv.dtype) - - torch.manual_seed(0) repeats = 30 batch_size = 8 @@ -130,9 +103,6 @@ def attention_megatron(qkv): # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') # # pytorch_profiler(attention, q, k, v, 1.0, backward=True) -# if scaled_upper_triang_masked_softmax is not None: -# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') - # from src.ops.fftconv import fftconv_func # dim = nheads * headdim diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index df0d56b8f2..3f3639e0b5 100644 --- a/benchmarks/benchmark_gemm.py +++ 
b/benchmarks/benchmark_gemm.py @@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs torch.manual_seed(0) repeats = 30 -dtype = torch.float16 +dtype = torch.bfloat16 device = 'cuda' verbose = False m, n = 8192, 8192 diff --git a/csrc/composable_kernel b/csrc/composable_kernel index 888317e698..e8709c24f4 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit 888317e698e9803c62bd38568abc9e05d7709f33 +Subproject commit e8709c24f403173ad21a2da907d1347957e324fb diff --git a/csrc/cutlass b/csrc/cutlass index c506e16788..dc4817921e 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit c506e16788cb08416a4a57e11a9067beeee29420 +Subproject commit dc4817921edda44a549197ff3a9dcf5df0636e7b diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index b8158fc940..a7b5d36835 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -432,7 +432,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -515,7 +515,7 @@ std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. @@ -644,7 +644,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -831,7 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -1048,7 +1048,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); @@ -1321,7 +1321,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); - const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index e34dd2454b..0000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 5089d988d9..0000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 0272c57975..0000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
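The repeated head_size_rounded change above (applied to the forward, varlen-forward, backward, varlen-backward and kvcache entry points) switches from "round to a multiple of 32 up to 192, else 256" to "round to a multiple of 32 up to head size 128, else a multiple of 64". A plain-Python restatement of the two C++ lambdas, for illustration only:

```python
def round_multiple(x, m):
    return (x + m - 1) // m * m

def old_rounded(head_size):   # head_size <= 192 ? round_multiple(head_size, 32) : 256
    return round_multiple(head_size, 32) if head_size <= 192 else 256

def new_rounded(head_size):   # round_multiple(head_size, head_size <= 128 ? 32 : 64)
    return round_multiple(head_size, 32 if head_size <= 128 else 64)

for hs in (64, 96, 128, 160, 192, 224, 256):
    print(hs, old_rounded(hs), new_rounded(hs))
# Only head sizes in (128, 160] change: they now round up to 192 instead of 160,
# which is consistent with the dedicated hdim160 kernels being removed in this diff.
```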
See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu deleted file mode 100644 index d3d5d98d12..0000000000 --- a/csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_bwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 8f42f0ae10..50af5f6307 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -118,7 +118,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM; diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index b719cf9887..72e7a333b3 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -102,7 +102,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -261,26 +261,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { }); } -template -void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; - int device; - cudaGetDevice(&device); - int max_smem_per_block; - cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); - if (status_ != cudaSuccess) { - C10_CUDA_CHECK(status_); - } - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_bwd, Is_dropout, Is_causal>(params, stream); - } - }); -} - template void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index 016a010709..e4875fe3a1 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -79,7 +79,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM; @@ -205,7 +205,7 @@ inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index 27d9e9d8a7..0000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. 
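On the `128 * bidb` to `128ll * bidb` changes above: the `ll` suffix promotes the dq_accum offset arithmetic to 64-bit, so the padded per-batch offset can no longer wrap around a 32-bit int for large varlen workloads. A rough sanity check of the magnitudes involved (illustrative numbers only; the actual field types live in the params structs, which are not shown in this diff):

```python
INT32_MAX = 2**31 - 1

def dq_accum_row_offset(m_block, kBlockM, bidb, h, d_rounded):
    # Shape of the pre-fix expression: (m_block*kBlockM + 128*bidb) * h * d_rounded
    return (m_block * kBlockM + 128 * bidb) * h * d_rounded

# e.g. ~256K rounded query tokens (2048 blocks of 128), 32 heads, head dim rounded to 256
off = dq_accum_row_offset(m_block=2047, kBlockM=128, bidb=16, h=32, d_rounded=256)
print(off, off > INT32_MAX)  # exceeds INT32_MAX, hence the promotion to 64-bit math
```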
-// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu deleted file mode 100644 index 943e508eb1..0000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 92904627b9..0000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu deleted file mode 100644 index 7b3749e255..0000000000 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); -} - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 1ba07da157..d492c87b5c 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -362,7 +362,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
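Each of the hdim160 translation units removed in this patch is a thin explicit instantiation that forwards to the launch template. A representative sketch of one of them follows; the template arguments are reconstructed from the surrounding launch templates and are approximate, not copied verbatim.

```cuda
// Approximate shape of one deleted auto-generated hdim160 translation unit
// (template arguments reconstructed, not verbatim from the tree).
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160, /*Is_causal=*/false>(Flash_fwd_params &params,
                                                                 cudaStream_t stream) {
    run_mha_fwd_hdim160<cutlass::bfloat16_t, /*Is_causal=*/false>(params, stream);
}

}  // namespace FLASH_NAMESPACE
```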
- Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); // if (cute::thread0()) { print(tOrP); } FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -424,7 +424,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } @@ -922,7 +922,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); @@ -987,7 +987,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 227f3c2572..934e7b9114 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -76,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); // auto kernel = &flash_fwd_kernel; @@ -117,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
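The "reduce number of templates" comments in these launchers refer to the BOOL_SWITCH-style dispatch used throughout the file: a runtime flag is turned into a constexpr template parameter, so every independent flag doubles the number of kernel instantiations, and tying IsEvenMNConst to IsEvenKConst prunes that growth. A condensed sketch of the pattern, close to but not verbatim from static_switch.h:

```cuda
// Condensed BOOL_SWITCH sketch: the runtime condition selects a constexpr
// bool, so the lambda body is instantiated exactly twice per switch.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                           \
    if (COND) {                                   \
      constexpr static bool CONST_NAME = true;    \
      return __VA_ARGS__();                       \
    } else {                                      \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }                                             \
  }()

// Illustrative use:
//   BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_kernel<Is_causal>(params, stream); });
```

Collapsing IsEvenMNConst to false whenever IsEvenKConst is false removes one of these doublings for the affected launcher.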
// If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { @@ -165,7 +165,6 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) constexpr static int kBlockM = 64; // Fixed for all head dimensions // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. - // Also for headdim 160 with block size 64 x 128 after the rotary addition. constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); run_flash_splitkv_fwd, Is_causal>(params, stream); } @@ -257,34 +256,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template -void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int Headdim = 160; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); -} - template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu deleted file mode 100644 index f5167b3339..0000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu deleted file mode 100644 index ee02db1a34..0000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. 
See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu deleted file mode 100644 index 2b0472038f..0000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu deleted file mode 100644 index 2b833bd537..0000000000 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2024, Tri Dao. -// Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" -#include "flash_fwd_launch_template.h" - -namespace FLASH_NAMESPACE { - -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); - -} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 7b2130babb..834bd22bd0 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -10,7 +10,7 @@ } SM = [80] # Sm80 kernels support up to -HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256] +HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] IS_CAUSAL = ["false", "true"] NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index a57702f6ce..70d14daf69 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -101,9 +101,6 @@ } else if (HEADDIM <= 128) { \ constexpr static int kHeadDim = 128; \ return __VA_ARGS__(); \ - } else if (HEADDIM <= 160) { \ - constexpr static int kHeadDim = 160; \ - return __VA_ARGS__(); \ } else if (HEADDIM <= 192) { \ constexpr static int kHeadDim = 192; \ return __VA_ARGS__(); \ diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index a386768216..68e2835518 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -19,6 +19,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, dtype, false, // is_group_mode true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -111,6 +112,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -134,6 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + 0, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp index bcb8e3bbb9..27866f1902 100644 --- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp +++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp @@ -33,7 +33,8 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask, head_size, dtype, false, // is_group_mode - true, // is_v_rowmajor + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 6274750f58..3e4422efec 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -17,8 +17,9 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, return fmha_fwd_traits{head_size, head_size, dtype, - true, // is_group_mode - true, // is_v_rowmajor + true, // is_group_mode + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -35,8 +36,9 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m return fmha_fwd_splitkv_traits{head_size, head_size, dtype, - true, // is_group_mode - true, // is_v_rowmajor + true, // is_group_mode + true, // is_v_rowmajor + false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, @@ -131,6 +133,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, softmax_scale, // scale_s 1, // scale_p 1, // scale_o + 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, @@ -154,6 +157,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, mask.left, mask.right, static_cast(mask.type), + 0, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset}; diff --git a/csrc/ft_attention/README.md b/csrc/ft_attention/README.md deleted file mode 100644 index 97feb78cc1..0000000000 --- a/csrc/ft_attention/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Attention kernel from FasterTransformer - -This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from -FasterTransformer v5.2.1 for benchmarking purpose. - -```sh -cd csrc/ft_attention && pip install . -``` - -As of 2023-09-17, this extension is no longer used in the FlashAttention repo. -FlashAttention now has implemented -[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py) -with all the features of this `ft_attention` kernel (and more). 
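Together with the deleted run_mha_fwd_hdim160 / run_mha_bwd_hdim160 launchers, the static_switch.h hunk above drops the dedicated 160 bucket from head-dimension dispatch, so a head dimension of 160 now rounds up to the 192 kernels. A condensed sketch of the resulting dispatch (the full macro also carries 32 and 96 cases):

```cuda
// Condensed HEADDIM_SWITCH sketch after the 160 case is removed:
// any head dim in 129..192, including 160, now dispatches to kHeadDim = 192.
#define HEADDIM_SWITCH(HEADDIM, ...)         \
  [&] {                                      \
    if (HEADDIM <= 64) {                     \
      constexpr static int kHeadDim = 64;    \
      return __VA_ARGS__();                  \
    } else if (HEADDIM <= 128) {             \
      constexpr static int kHeadDim = 128;   \
      return __VA_ARGS__();                  \
    } else if (HEADDIM <= 192) {             \
      constexpr static int kHeadDim = 192;   \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr static int kHeadDim = 256;   \
      return __VA_ARGS__();                  \
    }                                        \
  }()
```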
- diff --git a/csrc/ft_attention/cuda_bf16_fallbacks.cuh b/csrc/ft_attention/cuda_bf16_fallbacks.cuh deleted file mode 100644 index f5641f6160..0000000000 --- a/csrc/ft_attention/cuda_bf16_fallbacks.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include - -namespace fastertransformer { - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); 
-#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; -} - -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, 
__nv_bfloat16 b, __nv_bfloat16 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace fastertransformer diff --git a/csrc/ft_attention/cuda_bf16_wrapper.h b/csrc/ft_attention/cuda_bf16_wrapper.h deleted file mode 100644 index efb6e79873..0000000000 --- a/csrc/ft_attention/cuda_bf16_wrapper.h +++ /dev/null @@ -1,23 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu deleted file mode 100644 index 13306f7686..0000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,149 +0,0 @@ -// Adapted from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include -#include -#include - -#include "decoder_masked_multihead_attention_template.hpp" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - auto kernel = mmha::masked_multihead_attention_kernel; \ - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - dim3 grid(params.nnz_head_idx == nullptr ? params.num_heads : params.nnz_heads, params.batch_size); \ - kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); - if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); - } - else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); - } - else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#undef MMHA_LAUNCH_KERNEL - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif 
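In the deleted mmha_launch_kernel above, THREADS_PER_VALUE is derived as `Dh_MAX * sizeof(T) / 16`, i.e. each value-processing thread owns one 16-byte slice of the head dimension. A few illustrative compile-time checks make the arithmetic concrete; the helper name is made up for the sketch.

```cuda
// Illustrative checks for THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16:
// one thread per 16-byte slice of the head dimension.
#include <cstdint>

constexpr int threads_per_value(int dh_max, int elt_bytes) {
    return dh_max * elt_bytes / 16;
}

static_assert(threads_per_value(128, sizeof(uint16_t)) == 16, "FP16, Dh_MAX = 128");
static_assert(threads_per_value(128, sizeof(float))    == 32, "FP32, Dh_MAX = 128");
static_assert(threads_per_value(256, sizeof(uint16_t)) == 32, "FP16, Dh_MAX = 256");
```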
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h deleted file mode 100644 index 3c79f88b85..0000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,192 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. 
The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride_q = 0; - int stride_k = 0; - int stride_v = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - int num_heads_kv = 0; - int num_heads_q_kv_ratio = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - float rotary_base = 0.0f; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; - - const T *rotary_cos = nullptr; - const T *rotary_sin = nullptr; - - const int *nnz_head_idx = nullptr; - int nnz_heads = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const 
Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 2ae1b2425b..0000000000 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1619 +0,0 @@ -// Downloaded from from FasterTransformer v5.2.1 -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp -/* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "decoder_masked_multihead_attention.h" -#include "decoder_masked_multihead_attention_utils.h" -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -#define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. 
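The comment above describes the K-cache layout as [B, H, Dh/x, L, x], with x sized so that each chunk is 16 bytes. A sketch of the corresponding offset computation; the function and argument names are illustrative, not taken from the kernel.

```cuda
// Offset of element (b, h, d, l) in the K-cache layout [B, H, Dh/x, L, x],
// where x = 16 / sizeof(T) (8 for FP16, 4 for FP32). Illustrative only.
#include <cstddef>

template <typename T>
__host__ __device__ size_t k_cache_offset(int b, int h, int d, int l,
                                          int H, int Dh, int L) {
    constexpr int x = 16 / sizeof(T);   // elements per 16-byte chunk
    const int chunk  = d / x;           // which Dh/x slice of the head dim
    const int within = d % x;           // position inside the 16-byte chunk
    return ((((size_t)b * H + h) * (Dh / x) + chunk) * L + l) * (size_t)x + within;
}
```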
The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. -// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_ { -}; - -template<> -struct Qk_vec_ { - using Type = float; -}; -template<> -struct Qk_vec_ { - using Type = float2; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = float4; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_ { - using Type = uint2; -}; -template<> -struct Qk_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_ { -}; - -template<> -struct K_vec_ { - using Type = float; -}; -template<> -struct K_vec_ { - using Type = float2; -}; -template<> -struct K_vec_ { - using Type = float4; -}; -template<> -struct K_vec_ { - using Type = uint32_t; -}; -template<> -struct K_vec_ { - using Type = uint2; -}; -template<> -struct K_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_ { -}; - -template<> -struct V_vec_ { - using Type = float; -}; -template<> -struct V_vec_ { - using Type = float2; -}; -template<> -struct V_vec_ { - using Type = float4; -}; -template<> -struct V_vec_ { - using Type = uint32_t; -}; -template<> -struct V_vec_ { - using Type = uint2; -}; -template<> -struct V_vec_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - 
-template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. 
- float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. 
- return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
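Both qk_dot_ and block_sum above rely on the xor-shuffle butterfly to sum across lanes. A standalone version of that warp-level reduction, as a sketch rather than the file's own helper:

```cuda
// Xor-shuffle butterfly: after log2(32) = 5 steps every lane of the warp
// holds the full sum of the 32 per-lane values.
__device__ __forceinline__ float warp_all_reduce_sum(float v) {
    #pragma unroll
    for (int mask = 32 / 2; mask >= 1; mask /= 2) {
        v += __shfl_xor_sync(0xffffffffu, v, mask);
    }
    return v;
}
```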
-inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off -inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - 
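The float_from_int8 helpers above unpack int8 lanes through unions, and the cast_to_int8 helpers that follow quantize with the cvt.rni.sat.s8.f32 PTX instruction. A host-side reference for that conversion, assuming the default round-to-nearest-even mode; this is a sketch, not code from the repository.

```cuda
// Host-side reference for cvt.rni.sat.s8.f32: round to the nearest integer
// (ties to even), then saturate to the signed 8-bit range.
#include <cmath>
#include <cstdint>

static int8_t cast_to_int8_ref(float v) {
    float r = std::nearbyint(v);                     // round to nearest, ties to even
    r = std::fmin(std::fmax(r, -128.0f), 127.0f);    // saturate to [-128, 127]
    return static_cast<int8_t>(r);
}
```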
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-inline __device__ int8_t cast_to_int8(float val)
-{
-    union {
-        int8_t int8[2];
-        int16_t int16;
-    };
-    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
-    return int8[0];
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-inline __device__ int32_t cast_to_int8(float4 val)
-{
-    union {
-        int8_t int8[4];
-        int32_t int32;
-    };
-    int8[0] = cast_to_int8(val.x);
-    int8[1] = cast_to_int8(val.y);
-    int8[2] = cast_to_int8(val.z);
-    int8[3] = cast_to_int8(val.w);
-    return int32;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-inline __device__ int64_t cast_to_int8(Float8_ val)
-{
-    union {
-        int8_t int8[8];
-        int64_t int64;
-    };
-    int8[0] = cast_to_int8(val.x.x);
-    int8[1] = cast_to_int8(val.x.y);
-    int8[2] = cast_to_int8(val.y.x);
-    int8[3] = cast_to_int8(val.y.y);
-    int8[4] = cast_to_int8(val.z.x);
-    int8[5] = cast_to_int8(val.z.y);
-    int8[6] = cast_to_int8(val.w.x);
-    int8[7] = cast_to_int8(val.w.y);
-    return int64;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T>
-inline __device__ __host__ T div_up(T m, T n)
-{
-    return (m + n - 1) / n;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T, bool DO_CROSS_ATTENTION>
-inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
-                                 int threads_per_value,
-                                 int threads_per_block)
-{
-    // The amount of shared memory needed to store the Q*K^T values in float.
-    const int max_timesteps = min(params.timestep, params.memory_max_len);
-    size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
-
-    // The extra memory needed if we are not using floats for the final logits.
-    size_t logits_sz = 0;
-#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
-    if (sizeof(T) != 4) {
-        // TODO
-        logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
-                                           div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
-    }
-#endif
-
-    // The total size needed during softmax.
-    size_t softmax_sz = qk_sz + logits_sz;
-
-    // The number of partial rows to reduce in the final reduction.
-    int rows_per_red = threads_per_block / threads_per_value;
-    // The amount of storage needed to finalize the outputs.
-    size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
-
-    size_t transpose_rotary_size = 0;
-    if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
-        transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
-    }
-
-    // The max.
-    return max(max(softmax_sz, red_sz), transpose_rotary_size);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-inline __device__ constexpr uint32_t shfl_mask(int threads)
-{
-    return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<
-    // The type of the inputs. Supported types: float and half.
-    typename T,
-    // The hidden dimension per head.
-    int Dh,
-    int Dh_MAX,
-    // The number of threads per key.
-    int THREADS_PER_KEY,
-    // The number of threads per value.
-    int THREADS_PER_VALUE,
-    // The number of threads in a threadblock.
- int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) -{ - - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(T) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; - } - T* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - T* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - - // Use alignment for safely casting the shared buffers as Qk_vec. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; - - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec = typename Qk_vec_::Type; - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. - constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - // The head. - // const int hi = blockIdx.x; - const int hi = params.nnz_head_idx == nullptr ? 
blockIdx.x : params.nnz_head_idx[blockIdx.x]; - const int hi_kv = hi / params.num_heads_q_kv_ratio; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bhi_kv = bi * params.num_heads_kv + hi_kv; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv; - // The thread in the block. - const int tidx = threadIdx.x; - - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh; - int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh; - int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - int q_offset = q_base_offset + tidx * QK_VEC_SIZE; - int k_offset = k_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = *reinterpret_cast(¶ms.q[q_offset]); - } - } - - Qk_vec k; - zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? 
- *reinterpret_cast(¶ms.k_cache[offset]) : - k; - } - else { - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = *reinterpret_cast(¶ms.k[k_offset]); - } - } - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec q_bias; - zero(q_bias); - q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : - q_bias; - - Qk_vec k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? - *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - else { - if (params.rotary_cos == nullptr) { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? 
QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - if (params.rotary_cos == nullptr) { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); - } else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, - params.rotary_cos + bi * params.rotary_embedding_dim / 2, - params.rotary_sin + bi * params.rotary_embedding_dim / 2); - } - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = k; - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. 
- if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec = typename K_vec_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // The keys loaded from the key cache. 
- K_vec k[K_VECS_PER_THREAD]; - K_vec k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (has_beams) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); - } - else { - k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); - } - } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - *reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE])); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? 
red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec = typename V_vec_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - - // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = v_bias; - } - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. - __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. 
- // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; - // Load the values from the cache. - V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, *reinterpret_cast(&bias_smem[vi])); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = v; - } - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else - T logit = logits_smem[ti - first_step]; - - // Update the partial sums. - out = fma(logit, v, out); -#endif - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec v; - if (DO_CROSS_ATTENTION) { - v = *reinterpret_cast(&v_cache[tlength * Dh]); - } - else { - // Trigger the loads from the V buffer. - const auto v_offset = v_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = *reinterpret_cast(¶ms.v[v_offset]); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - } - - // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - if (hi % params.num_heads_q_kv_ratio == 0) { - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; - } - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else - // out = fma(logits_smem[params.timestep], v, out); - out = fma(logits_smem[tlength - first_step], v, out); -#endif - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. 
-            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-                convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out);
-#else
-                *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out;
-#endif
-            }
-            __syncthreads();
-
-            // The bottom warps update their values.
-            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
-                out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out);
-            }
-            __syncthreads();
-        }
-    }
-
-    // Output the final values.
-    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
-#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
-        if (params.int8_mode == 2) {
-            using Packed_Int8_t = typename packed_type::value>::type;
-            out = mul(*params.attention_out_scale, out);
-            *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) =
-                cast_to_int8(out);
-        }
-        else {
-            convert_from_float(*reinterpret_cast(&params.out[bhi * Dh + vi]), out);
-        }
-#else
-        // TODO: support int8_mode?
-        *reinterpret_cast(&params.out[bhi * Dh + vi]) = out;
-#endif
-    }
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-} // namespace mmha
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T, typename KERNEL_PARAMS_TYPE>
-void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
deleted file mode 100644
index 98875aba9b..0000000000
--- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+++ /dev/null
@@ -1,2017 +0,0 @@
-// Downloaded from FasterTransformer v5.2.1
-// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-/*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#pragma once - -#include "cuda_bf16_wrapper.h" -#include "cuda_bf16_fallbacks.cuh" -#include - -using namespace fastertransformer; - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct num_elems; -template<> -struct num_elems { - static constexpr int value = 1; -}; -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -template<> -struct num_elems { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; - -#ifdef ENABLE_BF16 -template<> -struct num_elems<__nv_bfloat162> { - static constexpr int value = 2; -}; -template<> -struct num_elems { - static constexpr int value = 4; -}; -template<> -struct num_elems { - static constexpr int value = 8; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct packed_type; -template -struct packed_type { - using type = T; -}; -template<> -struct packed_type { - using type = int16_t; -}; -template<> -struct packed_type { - using type = int32_t; -}; -template<> -struct packed_type { - using type = int64_t; -}; - -template<> -struct packed_type { - using type = float2; -}; -template<> -struct packed_type { - using type = float4; -}; -template<> -struct packed_type { - using type = Float8_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - 
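bf16hadd2 comes from cuda_bf16_fallbacks.cuh, which is not part of this hunk; it wraps the packed bfloat16 add so the code also builds for pre-sm_80 targets. A rough standalone sketch of that idea using only the public cuda_bf16.h intrinsics (the name is a placeholder and the exact fallback the file relies on may differ):

#include <cuda_bf16.h>

// Packed bfloat16 add: native on sm_80+, otherwise done through float conversions.
__device__ __nv_bfloat162 add_bf16x2(__nv_bfloat162 a, __nv_bfloat162 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __hadd2(a, b);
#else
    return __floats2bfloat162_rn(__bfloat162float(a.x) + __bfloat162float(b.x),
                                 __bfloat162float(a.y) + __bfloat162float(b.y));
#endif
}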
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? 
- float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, uint16_t b) -{ - return a + half_to_float(b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float add(float a, __nv_bfloat16 b) -{ - return a + __bfloat162float(b); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ 
float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - 
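The fma overloads that take a lone uint16_t all use the same trick: h0_h0 splats one fp16 scalar into both lanes of a packed register and the f16x2 FMA does the rest. With the public cuda_fp16.h types the same idea looks like the sketch below; the name is a placeholder, and the file itself keeps half pairs in raw uint32_t registers rather than __half2.

#include <cuda_fp16.h>

// d = a * b + c with a broadcast scalar: the h0_h0(a) + fma.rn.f16x2 pattern.
__device__ __half2 fma_scalar_h2(__half a, __half2 b, __half2 c)
{
    return __hfma2(__half2half2(a), b, c);
}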
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = 
fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - 
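Unlike add and fma, mul carries an explicit accumulator type parameter, so the same call site can either stay in the storage type or widen to float; that is the knob the MMHA_USE_FP32_ACUM_* macros turn in the attention kernel. A tiny scalar illustration of the pattern (demo names only, same shape as the specializations above and below):

// The accumulator type Acc is chosen by the caller, independently of A and B.
template <typename Acc, typename A, typename B>
__host__ __device__ Acc mul_demo(A a, B b)
{
    return static_cast<Acc>(a) * static_cast<Acc>(b);
}
// e.g. mul_demo<float>(int8_t(3), 0.5f) multiplies in fp32 even though one operand is int8.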
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(float a, Float8_ b) -{ - Float8_ c; - c.x = make_float2(a * b.x.x, a * b.x.y); - c.y = make_float2(a * b.y.x, a * b.y.y); - c.z = make_float2(a * b.z.x, a * b.z.y); - c.w = make_float2(a * b.w.x, a * b.w.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, float b) -{ - return half_to_float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ 
mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return 
fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, float b) -{ - return __bfloat162float(a) * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - 
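The element-wise mul overloads above and the sum reductions around them compose into the dot() template defined a little further down in this header (dot(a, b) = sum(mul(a, b))). A minimal stand-alone sketch of that pattern, using hypothetical names (f2, mul_demo, sum_demo, dot_demo) rather than the header's packed types, purely to make the composition explicit:

struct f2 { float x, y; };                                              // stand-in for CUDA's float2
inline f2    mul_demo(f2 a, f2 b) { return {a.x * b.x, a.y * b.y}; }    // lane-wise product
inline float sum_demo(f2 v)       { return v.x + v.y; }                 // reduce lanes to a scalar
inline float dot_demo(f2 a, f2 b) { return sum_demo(mul_demo(a, b)); }  // == a.x*b.x + a.y*b.y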
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const int t_step, const float base) -{ - const float pos_idx_inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); - return {cos(pos_idx_inv_freq), sin(pos_idx_inv_freq)}; -} - -inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) -{ - float2 rot_v; - rot_v.x = coef.x * v.x - coef.y * v.y; - rot_v.y = coef.x * v.y + coef.y * v.x; - return rot_v; -} - -inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) -{ - float2 fv = half2_to_float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return float2_to_half2(rot_fv); -} - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) -{ - float2 fv = bf1622float2(v); - float2 rot_fv = rotary_embedding_transform(fv, coef); - return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); -} -#endif - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline 
__device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = 
rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, 
coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const T* rotary_cos, const T* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. - return {float(rotary_cos[zid / 2]), float(rotary_sin[zid / 2])}; -} - -// fp16 is special because we use uint16_t for reading the data, for backward compatibility. -template <> -inline __device__ float2 rotary_embedding_coefficient(const int zid, const int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - // zid is the index of the dimension (0, 2, 4, ..., rotary_dim). - // rotary_cos/sin stores those at index 0, 1, 2, ..., rotary_dim / 2. 
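// (Annotation, not part of the deleted file.) As in the computed variant earlier in this
// header, the returned pair is (cos(theta), sin(theta)) with
// theta = t_step / base^(zid / rot_embed_dim); here the caller is expected to supply those
// cos/sin values precomputed per timestep, indexed by zid / 2, and this uint16_t
// specialization only reinterprets the raw 16-bit storage as fp16 before widening to float.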
- return {float(reinterpret_cast(rotary_cos)[zid / 2]), - float(reinterpret_cast(rotary_sin)[zid / 2])}; -} - -inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - return; -} - -inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float* rotary_cos, const float* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - - Float4_& q_ = *reinterpret_cast(&q); - Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q_.x = rotary_embedding_transform(q_.x, coef0); - k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q_.y = rotary_embedding_transform(q_.y, coef1); - k_.y = rotary_embedding_transform(k_.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, 
t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const uint16_t* rotary_cos, const uint16_t* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} - -#ifdef ENABLE_BF16 -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); -} - -inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (2 * tid >= rot_embed_dim) { - return; - } - const auto coef = rotary_embedding_coefficient(2 * tid, t_step, rotary_cos, rotary_sin); - q = rotary_embedding_transform(q, coef); - k = rotary_embedding_transform(k, coef); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, 
rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (4 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(4 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); -} - -inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const __nv_bfloat16* rotary_cos, const __nv_bfloat16* rotary_sin) -{ - if (8 * tid >= rot_embed_dim) { - return; - } - const auto coef0 = rotary_embedding_coefficient(8 * tid, t_step, rotary_cos, rotary_sin); - q.x = rotary_embedding_transform(q.x, coef0); - k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, t_step, rotary_cos, rotary_sin); - q.y = rotary_embedding_transform(q.y, coef1); - k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, t_step, rotary_cos, rotary_sin); - q.z = rotary_embedding_transform(q.z, coef2); - k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, t_step, rotary_cos, rotary_sin); - q.w = rotary_embedding_transform(q.w, coef3); - k.w = rotary_embedding_transform(k.w, coef3); -} -#endif // ENABLE_BF16 - -template -__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint2 u32x2; - uint16_t 
u16[4]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - - vec = tmp_3.u32x2; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u16[0] = tmp_1.u16[0]; - tmp_3.u16[1] = tmp_2.u16[0]; - tmp_3.u16[2] = tmp_1.u16[1]; - tmp_3.u16[3] = tmp_2.u16[1]; - tmp_3.u16[4] = tmp_1.u16[2]; - tmp_3.u16[5] = tmp_2.u16[2]; - tmp_3.u16[6] = tmp_1.u16[3]; - tmp_3.u16[7] = tmp_2.u16[3]; - - vec = tmp_3.u32x4; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - __nv_bfloat16 bf16[2]; - } tmp_1, tmp_2; - tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; -} - -template<> -__device__ __inline__ void -vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - __nv_bfloat16 bf16[4]; - } tmp_1, tmp_2; - tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); - tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); - - vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; - vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; - vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; - vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; -} -#endif // ENABLE_BF16 - -template<> -__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.z = smem[transpose_idx + 1]; - vec.y = smem[smem_pitch + transpose_idx]; - vec.w = smem[smem_pitch + transpose_idx + 1]; -} - -template<> -__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - tmp.u16[0] = smem[transpose_idx]; - tmp.u16[1] = smem[smem_pitch + transpose_idx]; - - vec = tmp.u32; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} -#endif - -template<> -__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - vec.x = smem[transpose_idx]; - vec.y = smem[smem_pitch + transpose_idx]; -} - -template -__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); - -template<> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) -{ - return; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint64_t u64; - uint16_t u16[4]; - } tmp_1, tmp_2; - - union { - uint4 u32x4; - uint16_t u16[8]; - } tmp_3; - tmp_3.u32x4 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - 
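// (Annotation.) These assignments de-interleave the eight 16-bit lanes of vec: even lanes
// are gathered into tmp_1 and odd lanes into tmp_2, so tmp_1 is stored to row transpose_idx
// and tmp_2 to row smem_pitch + transpose_idx of shared memory -- the inverse of the
// vec_from_smem_transpose(uint4&, uint16_t*) load above.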
tmp_2.u16[1] = tmp_3.u16[3]; - tmp_1.u16[2] = tmp_3.u16[4]; - tmp_2.u16[2] = tmp_3.u16[5]; - tmp_1.u16[3] = tmp_3.u16[6]; - tmp_2.u16[3] = tmp_3.u16[7]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp_1, tmp_2; - - union { - uint2 u32x2; - uint16_t u16[4]; - } tmp_3; - tmp_3.u32x2 = vec; - tmp_1.u16[0] = tmp_3.u16[0]; - tmp_2.u16[0] = tmp_3.u16[1]; - tmp_1.u16[1] = tmp_3.u16[2]; - tmp_2.u16[1] = tmp_3.u16[3]; - - *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; - *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = vec; - - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -template<> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[transpose_idx + 1] = vec.z; - smem[smem_pitch + transpose_idx] = vec.y; - smem[smem_pitch + transpose_idx + 1] = vec.w; -} - -template<> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) -{ - union { - uint32_t u32; - half u16[2]; - } tmp; - - tmp.u32 = vec; - smem[transpose_idx] = tmp.u16[0]; - smem[smem_pitch + transpose_idx] = tmp.u16[1]; -} - -#ifdef ENABLE_BF16 -template<> -__device__ __inline__ void -write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} - -template<> -__device__ __inline__ void -write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) -{ - write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); -} -#endif - -template<> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) -{ - smem[transpose_idx] = vec.x; - smem[smem_pitch + transpose_idx] = vec.y; -} - -} // namespace mmha diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp deleted file mode 100644 index 886da9729b..0000000000 --- a/csrc/ft_attention/ft_attention.cpp +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" -#include - - -#include "decoder_masked_multihead_attention.h" - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") - -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) 
\ - if (TYPE == at::ScalarType::Half) { \ - using scalar_t = at::Half; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::BFloat16) { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (TYPE == at::ScalarType::Float) { \ - using scalar_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ - } - -template -void masked_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -void cross_multihead_attention(const Masked_multihead_attention_params& params, - const cudaStream_t& stream); - -template -struct SATypeConverter { - using Type = T; -}; - -template<> -struct SATypeConverter { - using Type = uint16_t; -}; - -template<> -struct SATypeConverter { - using Type = __nv_bfloat16; -}; - -template -void set_params(Masked_multihead_attention_params ¶ms, - const size_t batch_size, - const size_t nheads, - const size_t nheads_kv, - const size_t memory_max_seqlen, - const size_t headdim, - const int timestep, - const int rotary_embedding_dim, - const float rotary_base, - const bool neox_rotary_style, - const int q_batch_stride, - const int k_batch_stride, - const int v_batch_stride, - const int nnz_heads, - T *q_ptr, - T *k_ptr, - T *v_ptr, - T *k_cache_ptr, - T *v_cache_ptr, - int *length_per_sample, - T *rotary_cos, - T *rotary_sin, - T *out_ptr, - int *nnz_head_idx) { - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.q_bias = nullptr; - params.k_bias = nullptr; - params.v_bias = nullptr; - params.k_cache = k_cache_ptr; - params.v_cache = v_cache_ptr; - params.out = out_ptr; - params.cache_indir = nullptr; - params.stride_q = q_batch_stride; - params.stride_k = k_batch_stride; - params.stride_v = v_batch_stride; - params.batch_size = batch_size; - params.beam_width = 1; - params.memory_max_len = memory_max_seqlen; - params.num_heads = nheads; - params.num_heads_kv = nheads_kv; - params.num_heads_q_kv_ratio = nheads / nheads_kv; - params.nnz_heads = nnz_heads; - params.hidden_size_per_head = headdim; - params.rotary_embedding_dim = rotary_embedding_dim; - params.rotary_base = rotary_base; - params.neox_rotary_style = neox_rotary_style; - params.timestep = timestep; - params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); - params.total_padding_tokens = nullptr; - params.masked_tokens = nullptr; - params.prefix_prompt_lengths = nullptr; - params.max_prefix_prompt_length = 0; - params.relative_attention_bias = nullptr; - params.relative_attention_bias_stride = 0; - params.cross_attention_out = nullptr; - params.max_decoder_seq_len = 0; - params.is_return_cross_attentions = false; - params.finished = nullptr; - params.memory_length_per_sample = nullptr; - params.length_per_sample = length_per_sample; - params.rotary_cos = rotary_cos; - params.rotary_sin = rotary_sin; - params.nnz_head_idx = nnz_head_idx; -} - -torch::Tensor single_query_attention(const torch::Tensor q, - const torch::Tensor k, - const torch::Tensor v, - torch::Tensor k_cache, - torch::Tensor v_cache, - std::optional length_per_sample_, - std::optional rotary_cos_, - std::optional rotary_sin_, - std::optional nnz_head_idx_, - const int timestep, - int rotary_embedding_dim = 0, - const float rotary_base = 10000.0f, - const bool neox_rotary_style=true) { - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); - int batch_size = v_cache.size(0); - int nheads = q.size(1); - int 
nheads_kv = v_cache.size(1); - int memory_max_seqlen = v_cache.size(2); - int headdim = v_cache.size(3); - auto input_type = q.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - - CHECK_SHAPE(q, batch_size, nheads, headdim); - CHECK_SHAPE(k, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v, batch_size, nheads_kv, headdim); - CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); - // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 - int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; - CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); - TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); - TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); - TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); - CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); - - TORCH_CHECK(q.scalar_type() == input_type); - TORCH_CHECK(k.scalar_type() == input_type); - TORCH_CHECK(v.scalar_type() == input_type); - TORCH_CHECK(k_cache.scalar_type() == input_type); - TORCH_CHECK(v_cache.scalar_type() == input_type); - - if (length_per_sample_.has_value()) { - auto length_per_sample = length_per_sample_.value(); - CHECK_DEVICE(length_per_sample); - CHECK_SHAPE(length_per_sample, batch_size); - CHECK_CONTIGUOUS(length_per_sample); - TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); - rotary_embedding_dim = rotary_cos.size(-1) * 2; - CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_cos); - TORCH_CHECK(rotary_cos.scalar_type() == input_type); - - TORCH_CHECK(rotary_sin_.has_value()); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); - CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2); - CHECK_CONTIGUOUS(rotary_sin); - TORCH_CHECK(rotary_sin.scalar_type() == input_type); - } - - if (nnz_head_idx_.has_value()) { - auto nnz_head_idx = nnz_head_idx_.value(); - CHECK_DEVICE(nnz_head_idx); - int nnz_heads = nnz_head_idx.size(0); - CHECK_SHAPE(nnz_head_idx, nnz_heads); - CHECK_CONTIGUOUS(nnz_head_idx); - TORCH_CHECK(nnz_head_idx.dtype() == torch::kInt32); - } - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - torch::Tensor out = torch::empty_like(q); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { - using DataType = typename SATypeConverter::Type; - Masked_multihead_attention_params params; - set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep, - rotary_embedding_dim, rotary_base, neox_rotary_style, - q.stride(0), k.stride(0), v.stride(0), - nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0, - reinterpret_cast(q.data_ptr()), - reinterpret_cast(k.data_ptr()), - reinterpret_cast(v.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(v_cache.data_ptr()), - length_per_sample_.has_value() - ? length_per_sample_.value().data_ptr() : nullptr, - rotary_cos_.has_value() - ? reinterpret_cast(rotary_cos_.value().data_ptr()) : nullptr, - rotary_sin_.has_value() - ? reinterpret_cast(rotary_sin_.value().data_ptr()) : nullptr, - reinterpret_cast(out.data_ptr()), - nnz_head_idx_.has_value() ? 
nnz_head_idx_.value().data_ptr() : nullptr - ); - auto stream = at::cuda::getCurrentCUDAStream(); - masked_multihead_attention(params, stream); - }); - return out; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("single_query_attention", &single_query_attention, "Attention with a single query", - py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), - py::arg("length_per_sample_"), py::arg("rotary_cos_"), - py::arg("rotary_sin_"), py::arg("nnz_head_idx_"), - py::arg("timestep"), py::arg("rotary_embedding_dim")=0, - py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); -} diff --git a/csrc/ft_attention/setup.py b/csrc/ft_attention/setup.py deleted file mode 100644 index fa385ad768..0000000000 --- a/csrc/ft_attention/setup.py +++ /dev/null @@ -1,153 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -from setuptools import setup, find_packages -import subprocess - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). 
- print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--ft_attention") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("ft_attention is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="ft_attention", - sources=[ - "ft_attention.cpp", - "decoder_masked_multihead_attention.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-DENABLE_BF16", # TODO - "-O3", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="ft_attention", - version="0.1", - description="Attention for single query from FasterTransformer", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/fused_softmax/fused_softmax.cpp b/csrc/fused_softmax/fused_softmax.cpp deleted file mode 100644 index 2aaed91331..0000000000 --- a/csrc/fused_softmax/fused_softmax.cpp +++ /dev/null @@ -1,148 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return fwd_cuda(input, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, 
softmax_results, scale_factor); -} - -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - - m.def("scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); - - m.def("scaled_masked_softmax_get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); - - m.def("scaled_upper_triang_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("scaled_upper_triang_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); -} diff --git a/csrc/fused_softmax/scaled_masked_softmax.h b/csrc/fused_softmax/scaled_masked_softmax.h deleted file mode 100644 index 14b9f6e424..0000000000 --- a/csrc/fused_softmax/scaled_masked_softmax.h +++ /dev/null @@ -1,528 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? 
b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - // compute scale value to account for full mask - acc_t scale_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0; - } - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] * scale_value[i]/ sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 12: // 4096 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 13: // 8192 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 12: // 4096 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 13: // 8192 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_masked_softmax_cuda.cu deleted file mode 100644 index a08e752699..0000000000 --- a/csrc/fused_softmax/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,121 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches - ); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - void* input_grads_ptr = static_cast(input_grads.data_ptr()); - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(input_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads - ); - ); - return input_grads; -} -} -} -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h b/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index 21e93fb313..0000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,529 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. 
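// Tiling: each warp handles WARP_BATCH rows of at most next_power_of_two elements;
// WARP_ITERATIONS strided passes of WARP_SIZE lanes cover one row, and
// ELEMENTS_PER_LDG_STG widens each load/store to 4 elements once a row needs at least 4 iterations.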
- constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 12: // 4096 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 13: // 8192 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index 79ec30be36..0000000000 --- a/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,98 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. 
All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 8192); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/fused_softmax/setup.py b/csrc/fused_softmax/setup.py deleted file mode 100644 index 9c1c6ed76e..0000000000 --- a/csrc/fused_softmax/setup.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron -# We add the case where seqlen = 4k and seqlen = 8k -import os -import subprocess - -import torch -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - 
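# nvcc -V reports e.g. 'release 11.8, V11.8.89'; these two assignments pick out the
# major version ('11') and the first digit of the minor version ('8').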
bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -cc_flag = [] -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") - -setup( - name='fused_softmax_lib', - ext_modules=[ - CUDAExtension( - name='fused_softmax_lib', - sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'], - extra_compile_args={ - 'cxx': ['-O3',], - 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag) - } - ) - ], - cmdclass={ - 'build_ext': BuildExtension -}) diff --git a/csrc/fused_softmax/type_shim.h b/csrc/fused_softmax/type_shim.h deleted file mode 100644 index 815ec7ec88..0000000000 --- a/csrc/fused_softmax/type_shim.h +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ -switch(TYPE) \ -{ \ -case at::ScalarType::Half: \ - { \ -using scalar_t = at::Half; \ -__VA_ARGS__; \ -break; \ - } \ -case at::ScalarType::BFloat16: \ - { \ -using scalar_t = at::BFloat16; \ -__VA_ARGS__; \ -break; \ - } \ -default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -} diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp deleted file mode 100644 index 640eea423a..0000000000 --- a/csrc/rotary/rotary.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include - -#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj); - -void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - CHECK_DEVICE(x1); CHECK_DEVICE(x2); - CHECK_DEVICE(cos); CHECK_DEVICE(sin); - CHECK_DEVICE(out1); CHECK_DEVICE(out1); - TORCH_CHECK(x1.dtype() == x2.dtype()); - TORCH_CHECK(cos.dtype() == sin.dtype()); - TORCH_CHECK(out1.dtype() == out2.dtype()); - TORCH_CHECK(x1.dtype() == cos.dtype()); - TORCH_CHECK(x1.dtype() == out1.dtype()); - TORCH_CHECK(x1.sizes() == x2.sizes()); - TORCH_CHECK(cos.sizes() == sin.sizes()); - TORCH_CHECK(out1.sizes() == out2.sizes()); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{x1.device()}; - - apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); -} diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu deleted file mode 100644 index 2dd0ff3f6e..0000000000 --- a/csrc/rotary/rotary_cuda.cu +++ /dev/null @@ -1,45 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include - -void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, - const torch::Tensor cos, const torch::Tensor sin, - torch::Tensor out1, torch::Tensor out2, - const bool conj) { - auto iter = at::TensorIteratorConfig() - .add_output(out1) - .add_output(out2) - .add_input(x1) - .add_input(x2) - .add_input(cos) - .add_input(sin) - .check_all_same_dtype(false) - .promote_inputs_to_common_dtype(false) - .build(); - - if (!conj) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); - scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { - at::native::gpu_kernel_multiple_outputs( - iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, - scalar_t sin) -> thrust::tuple { - scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); - scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); - return {out1, out2}; - }); - }); - } -} \ No newline at end of file diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py deleted file mode 100644 index 24d328d9c6..0000000000 --- a/csrc/rotary/setup.py +++ /dev/null @@ -1,126 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = 
os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). 
- print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -raise_if_cuda_home_none("rotary_emb") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("rotary_emb is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - 'rotary_emb', [ - 'rotary.cpp', - 'rotary_cuda.cu', - ], - extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], - 'nvcc': append_nvcc_threads([ - '-O3', '--use_fast_math', '--expt-extended-lambda' - ] + cc_flag) - } - ) -) - -setup( - name="rotary_emb", - version="0.1", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md deleted file mode 100644 index 1bc90fdab7..0000000000 --- a/csrc/xentropy/README.md +++ /dev/null @@ -1,14 +0,0 @@ -This CUDA extension implements optimized cross-entropy loss, adapted from Apex's -[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). -We make it work for bfloat16 and support in-place backward to save memory. - -It has only been tested on A100s. - -```sh -cd csrc/xentropy && pip install . -``` - -As of 2023-09-15, this extension is no longer used in the FlashAttention repo. -We've instead switched to a Triton-based -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/cross_entropy.py). -See the CrossEntropyLoss [module](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py) for more details. 
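To make the computation concrete, here is a minimal plain-PyTorch sketch of label-smoothed softmax cross entropy that also returns the per-row logsumexp, i.e. the `max_log_sum_exp` value the fused backward consumes. The uniform-over-all-classes smoothing convention and the helper name `smoothed_cross_entropy_reference` are illustrative assumptions rather than the kernel's exact bookkeeping.

```python
import torch
import torch.nn.functional as F

def smoothed_cross_entropy_reference(logits, labels, smoothing=0.0):
    """Per-example label-smoothed cross entropy plus the per-row logsumexp."""
    logits = logits.float()                # compute in fp32 for numerical stability
    lse = torch.logsumexp(logits, dim=-1)  # per-row normalizer log(sum(exp(x)))
    nll = lse - logits.gather(-1, labels[:, None]).squeeze(-1)  # -log p_target
    if smoothing == 0.0:
        return nll, lse
    # Spread the smoothing mass uniformly over all classes (one common convention):
    # loss = (1 - eps) * (-log p_target) + eps * mean_c(-log p_c)
    mean_nll = lse - logits.mean(dim=-1)
    return (1.0 - smoothing) * nll + smoothing * mean_nll, lse

# Self-check against PyTorch's unsmoothed cross entropy.
x = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))
losses, _ = smoothed_cross_entropy_reference(x, y)
torch.testing.assert_close(losses, F.cross_entropy(x, y, reduction="none"))
```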
diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp deleted file mode 100644 index 41a783fd0f..0000000000 --- a/csrc/xentropy/interface.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include - -// CUDA forward declarations -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const int total_classes=-1) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - CHECK_INPUT(input); - CHECK_INPUT(labels); - - return softmax_xentropy_cuda(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes=-1) { - CHECK_INPUT(grad_loss); - CHECK_INPUT(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, - smoothing, inplace, total_classes); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); -} diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py deleted file mode 100644 index 5079b4f384..0000000000 --- a/csrc/xentropy/setup.py +++ /dev/null @@ -1,139 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -from packaging.version import parse, Version - -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME -from setuptools import setup, find_packages -import subprocess - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - torch_binary_version = parse(torch.version.cuda) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + 
cuda_dir + "/bin\n") - - if (bare_metal_version != torch_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.2"): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" - return nvcc_extra_args + ["--threads", nvcc_threads] - return nvcc_extra_args - - -if not torch.cuda.is_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" - elif bare_metal_version >= Version("11.1"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - elif bare_metal_version == Version("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split(".")[0]) -TORCH_MINOR = int(torch.__version__.split(".")[1]) - -cmdclass = {} -ext_modules = [] - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -raise_if_cuda_home_none("--xentropy") -# Check, if CUDA11 is installed for compute capability 8.0 -cc_flag = [] -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) -if bare_metal_version < Version("11.0"): - raise RuntimeError("xentropy is only supported on CUDA 11 and above") -cc_flag.append("-gencode") -cc_flag.append("arch=compute_70,code=sm_70") 
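# The '-gencode' pairs each target one GPU architecture: compute_70/sm_70 (Volta) above,
# compute_80/sm_80 (Ampere) next, and compute_90/sm_90 (Hopper) only when the local CUDA
# toolkit is 11.8 or newer.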
-cc_flag.append("-gencode") -cc_flag.append("arch=compute_80,code=sm_80") -if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - -ext_modules.append( - CUDAExtension( - name="xentropy_cuda_lib", - sources=[ - "interface.cpp", - "xentropy_kernel.cu" - ], - extra_compile_args={ - "cxx": ["-O3"] + generator_flag, - "nvcc": append_nvcc_threads( - ["-O3"] - + generator_flag - + cc_flag - ), - }, - include_dirs=[this_dir], - ) -) - -setup( - name="xentropy_cuda_lib", - version="0.1", - description="Cross-entropy loss", - ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if ext_modules else {}, -) diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu deleted file mode 100644 index 66aab0007b..0000000000 --- a/csrc/xentropy/xentropy_kernel.cu +++ /dev/null @@ -1,758 +0,0 @@ -// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu -// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). -/** - * From PyTorch: - * - * Copyright (c) 2016- Facebook, Inc (Adam Paszke) - * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - * Copyright (c) 2011-2013 NYU (Clement Farabet) - * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - * - * From Caffe2: - * - * Copyright (c) 2016-present, Facebook Inc. All rights reserved. - * - * All contributions by Facebook: - * Copyright (c) 2016 Facebook Inc. - * - * All contributions by Google: - * Copyright (c) 2015 Google Inc. - * All rights reserved. - * - * All contributions by Yangqing Jia: - * Copyright (c) 2015 Yangqing Jia - * All rights reserved. - * - * All contributions from Caffe: - * Copyright(c) 2013, 2014, 2015, the respective contributors - * All rights reserved. - * - * All other contributions: - * Copyright(c) 2015, 2016 the respective contributors - * All rights reserved. - * - * Caffe2 uses a copyright model similar to Caffe: each contributor holds - * copyright over their contributions to Caffe2. The project versioning records - * all such contribution and copyright details. If a contributor wants to further - * mark their specific copyright on a particular contribution, they should - * indicate their copyright solely in the commit message of the change when it is - * committed. - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - * and IDIAP Research Institute nor the names of its contributors may be - * used to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ -#include -#include -#include - -#include -#include - -// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } -// #else -// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ -// switch(TYPE) \ -// { \ -// case at::ScalarType::Float: \ -// { \ -// using scalar_t_##LEVEL = float; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// case at::ScalarType::Half: \ -// { \ -// using scalar_t_##LEVEL = at::Half; \ -// __VA_ARGS__; \ -// break; \ -// } \ -// default: \ -// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ -// } -// #endif - -#define ALIGN_BYTES 16 - -using Tensor = at::Tensor; -using TensorList = at::TensorList; -using ScalarType = at::ScalarType; -using at::acc_type; - -template -struct LogSoftMaxForwardEpilogue { - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} - - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} - - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } - - const AccumT logsum; -}; - -template -struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} - - __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { - return static_cast(gradOutput - std::exp(static_cast(output)) * sum); - } - - const AccumT sum; -}; - - - -const int max_threads = 1024; - -inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { - uint64_t block_size = 1; - uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; - // Launch at least a single warp - the kernel assumes that. 
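// Worked example (added note, not in the original source): the host-side callers
// below use ILP = sizeof(float4)/sizeof(scalar_t), i.e. ILP = 8 for fp16/bf16.
// With dim_size = 50257, max_block_size = min(50257/8, 1024) = 1024 and the
// doubling loop stops at 512 threads; with dim_size = 128 it stops at 8 and the
// std::max below bumps the launch up to a single warp of 32 threads.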
- block_size = std::max(block_size, static_cast(32)); - return dim3(block_size); -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// Regular kernel (fast when dim_size is large; requires inner_size == 1) -//////////////////////////////////////////////////////////////////////////////// - - -template -struct MaxFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { - return ::max(max, (AccumT)v); - } -}; - -template -struct AddFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + v; - } -}; - -template -struct SumExpFloat -{ - __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} - - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + std::exp(v - max_k); - } - - const AccumT max_k; -}; - -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val; - - __syncthreads(); - - AccumT warpVal = defaultVal; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); - } - __syncwarp(mask); - smem[lane] = warpVal; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal = defaultVal; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal = r(blockVal, smem[i]); - } - smem[0] = blockVal; - } - - // Sync and broadcast - __syncthreads(); - return smem[0]; -} - -template class Reduction1, template class Reduction2, typename AccumT> -__device__ __forceinline__ void -blockReduce(AccumT* smem, - AccumT* reducVal1, - AccumT val1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - AccumT val2, - const Reduction2& r2, - AccumT defaultVal2) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val1; - smem[blockDim.x + threadIdx.x] = val2; - - __syncthreads(); - - AccumT warpVal1 = defaultVal1; - AccumT warpVal2 = defaultVal2; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { -#pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal1 = r1(warpVal1, smem[lane * 32 + i]); - warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); - } - __syncwarp(mask); - smem[lane] = warpVal1; - smem[lane + blockDim.x] = warpVal2; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal1 = defaultVal1; - AccumT blockVal2 = defaultVal2; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { - blockVal1 = r1(blockVal1, smem[i]); - blockVal2 = r2(blockVal2, smem[i + blockDim.x]); - } - smem[0] = blockVal1; - 
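// (added note) the second reduction's result is staged blockDim.x entries further
// along, mirroring how val1 and val2 were written to smem[threadIdx.x] and
// smem[blockDim.x + threadIdx.x] at the top of this function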
smem[blockDim.x] = blockVal2; - } - - // Sync and broadcast - __syncthreads(); - *reducVal1 = smem[0]; - *reducVal2 = smem[blockDim.x]; - __syncthreads(); -} - -template class Reduction, int ILP, typename T, typename AccumT> -__device__ __forceinline__ AccumT -ilpReduce(int shift, - T* data, - int size, - const Reduction& r, - AccumT defaultVal) -{ - typedef typename std::aligned_storage::type LoadT; - AccumT threadVal = defaultVal; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal = r(threadVal, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal = r(threadVal, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) - threadVal = r(threadVal, data[offset]); - - return threadVal; -} - -template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> -__device__ __forceinline__ void -ilpReduce(int shift, - T* data, - int size, - AccumT* reducVal1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - const Reduction2& r2, - AccumT defaultVal2) -{ - typedef typename std::aligned_storage::type LoadT; - - AccumT threadVal1 = defaultVal1; - AccumT threadVal2 = defaultVal2; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal1 = r1(threadVal1, v[j]); - threadVal2 = r2(threadVal2, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) { - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - - *reducVal1 = threadVal1; - *reducVal2 = threadVal2; -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyForward( - accscalar_t *losses, - outscalar_t *max_log_sum_exp, - scalar_t *input, - int64_t *labels, - int64_t classes, - const float smoothing, - const int total_classes) -{ - extern __shared__ unsigned char smem[]; - auto sdata = reinterpret_cast(smem); - // forward pointers to batch[blockIdx.x] - // each block handles a sample in the mini-batch - input += blockIdx.x * classes; - //output += blockIdx.x * classes; - const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); - - int64_t label = labels[blockIdx.x]; - - // find the max and sum - accscalar_t threadMax, threadSum, max_k, sum_k; - ilpReduce( - shift, input, classes, - &threadMax, MaxFloat(), - -at::numeric_limits::max(), - &threadSum, AddFloat(), - static_cast(0)); - - blockReduce( - sdata, - &max_k, threadMax, Max(), - -at::numeric_limits::max(), - &sum_k, threadSum, Add(), - static_cast(0)); - - accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); - accscalar_t sumAll = blockReduce( - sdata, threadExp, Add(), static_cast(0)); - - Epilogue epilogue(max_k, sumAll); - - // 
calculate per element loss with label smoothing - // reserve max + log_sum_exp for bprop - if (threadIdx.x == 0) { - accscalar_t lse = max_k + std::log(sumAll); - accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; - losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); - max_log_sum_exp[blockIdx.x] = lse; - } -} - -template -__device__ __forceinline__ void -apply(scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - int last = classes % (ILP * blockDim.x); - - for (; offset < classes - last; offset += blockDim.x * ILP) { - accscalar_t tmpLogits[ILP]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); - } - -#pragma unroll - for (int j = 0; j < ILP; ++j) - gradInput[offset + j * blockDim.x] = tmpGradOutput * ( - std::exp(tmpLogits[j] - coeff) - static_cast( - (offset + j * blockDim.x == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast((offset == label) ? 1 : 0) * - smooth_positives - smooth_negatives); -} - - -template -__device__ __forceinline__ void -aligned_apply(int shift, - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / total_classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - logits -= shift; - gradInput -= shift; - classes += shift; - if(threadIdx.x >= shift){ - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - classes -= blockDim.x; - gradInput += blockDim.x; - logits += blockDim.x; - shift -= blockDim.x; - } - - int last = classes % (ILP * blockDim.x); - - typedef typename std::aligned_storage::type LoadT; - // input - scalar_t v[ILP]; - LoadT* value = reinterpret_cast(&v); - // output - scalar_t r[ILP]; - LoadT* result = reinterpret_cast(&r); - - for (; offset * ILP < (classes - last); offset += blockDim.x) { - *value = reinterpret_cast(logits)[offset]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - r[j] = tmpGradOutput * (std::exp( - static_cast(v[j]) - coeff) - - static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - reinterpret_cast(gradInput)[offset] = *result; - } - - offset = classes - last + threadIdx.x; - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 
1 : 0) * - smooth_positives - smooth_negatives); - -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyBackward( - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes, - const int total_classes) -{ - gradInput += blockIdx.x * classes; - logits += blockIdx.x * classes; - - // Do vectorized load/store when input/output have same alignment - const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); - const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); - if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); - } - -} - -template class Epilogue> -std::vector host_softmax_xentropy( - const Tensor & input_, - const Tensor & labels_, - const float smoothing, - const int total_classes) { - // For tensor parallel cross entropy with smoothing, we want to pass in the total number - // of classes so that smoothing can be applied correctly. If total_classes=-1, use the - // last dimension of the input tensor. - AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{input_.device()}; - - auto input = input_.contiguous(); - Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); - Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); - - const int64_t dim = 1; - int64_t outer_size = 1; - int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. - // XXX: it assumes that inner_size == 1 - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - using namespace at; - DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", - using accscalar_t = at::acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyForward - <<>>( - losses.data_ptr(), max_log_sum_exp.data_ptr(), - input.data_ptr(), labels_.data_ptr(), - dim_size, smoothing, total_classes <= 0 ? 
dim_size : total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - - std::vector ret = {losses, max_log_sum_exp}; - return ret; -} - -template class Epilogue> -Tensor host_softmax_xentropy_backward( - const at::Tensor &grad_loss, - at::Tensor &logits_, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - bool inplace, - const int total_classes) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{grad_loss.device()}; - - const int64_t dim = 1; - Tensor gI = inplace ? logits_ : at::empty_like(logits_); - if (grad_loss.numel() == 0) { - return gI; - } - - auto grad = grad_loss.contiguous(); - auto logits = logits_.contiguous(); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - if (grad.dim() == 0) grad = grad.view(1); - - AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); - AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); - - int64_t outer_size = 1; - int64_t dim_size = logits.size(dim); - int64_t inner_size = 1; - for (int64_t i = 0; i < dim; ++i) - outer_size *= logits.size(i); - for (int64_t i = dim + 1; i < logits.dim(); ++i) - inner_size *= logits.size(i); - // See descriptions of kernels above. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", - using accscalar_t = acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - cunn_SoftMaxXEntropyBackward - <<>>( - gI.data_ptr(), logits.data_ptr(), - max_log_sum_exp.data_ptr(), - grad.data_ptr(), labels.data_ptr(), - smoothing, dim_size, total_classes - ); - ); - - C10_CUDA_CHECK(cudaGetLastError()); - return gI; -} - -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ - return host_softmax_xentropy(input, labels, smoothing, total_classes); -} - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - const bool inplace, - const int total_classes) { - AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); -} diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 07d16cd0f4..4a8a7c33f4 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.3" +__version__ = "2.8.3" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/cute/README.md b/flash_attn/cute/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash_attn/cute/__init__.py b/flash_attn/cute/__init__.py new file mode 100644 index 0000000000..f1a4ed2d21 --- /dev/null +++ b/flash_attn/cute/__init__.py @@ -0,0 +1,13 @@ +"""Flash Attention CUTE (CUDA Template Engine) 
implementation.""" + +from .interface import ( + flash_attn_func, + flash_attn_varlen_func, +) + +__version__ = "0.1.0" + +__all__ = [ + "flash_attn_func", + "flash_attn_varlen_func", +] diff --git a/flash_attn/cute/ampere_helpers.py b/flash_attn/cute/ampere_helpers.py new file mode 100644 index 0000000000..839f407f75 --- /dev/null +++ b/flash_attn/cute/ampere_helpers.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Type, Callable, Optional + +import cutlass +import cutlass.cute as cute + + +def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout: + dtype_byte = cutlass.const_expr(dtype.width // 8) + bytes_per_row = cutlass.const_expr(k_dim * dtype_byte) + smem_k_block_size = cutlass.const_expr( + 128 + if bytes_per_row % 128 == 0 + else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16)) + ) // dtype_byte + swizzle_bits = ( + 4 + if smem_k_block_size == 128 + else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1)) + ) + swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4) + return cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base), + 0, + cute.make_ordered_layout((8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)), + ) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + tCsA: cute.Tensor, + tCsB: cute.Tensor, + smem_thr_copy_A: cute.TiledCopy, + smem_thr_copy_B: cute.TiledCopy, + hook_fn: Optional[Callable] = None, + A_in_regs: cutlass.Constexpr[bool] = False, + B_in_regs: cutlass.Constexpr[bool] = False, + swap_AB: cutlass.Constexpr[bool] = False, +) -> None: + if cutlass.const_expr(swap_AB): + gemm( + tiled_mma, + acc, + tCrB, + tCrA, + tCsB, + tCsA, + smem_thr_copy_B, + smem_thr_copy_A, + hook_fn, + A_in_regs=B_in_regs, + B_in_regs=A_in_regs, + swap_AB=False, + ) + else: + tCrA_copy_view = smem_thr_copy_A.retile(tCrA) + tCrB_copy_view = smem_thr_copy_B.retile(tCrB) + if cutlass.const_expr(not A_in_regs): + cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0]) + if cutlass.const_expr(not B_in_regs): + cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) + for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])): + if k < cute.size(tCsA.shape[2]) - 1: + if cutlass.const_expr(not A_in_regs): + cute.copy( + smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1] + ) + if cutlass.const_expr(not B_in_regs): + cute.copy( + smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1] + ) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + if cutlass.const_expr(k == 0 and hook_fn is not None): + hook_fn() + + +@cute.jit +def gemm_rs( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + tCsB: cute.Tensor, + smem_thr_copy_B: cute.TiledCopy, + hook_fn: Optional[Callable] = None, +) -> None: + tCrB_copy_view = smem_thr_copy_B.retile(tCrB) + cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0]) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1): + cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + if cutlass.const_expr(k == 0 and hook_fn is not None): + hook_fn() diff --git 
a/flash_attn/cute/blackwell_helpers.py b/flash_attn/cute/blackwell_helpers.py new file mode 100644 index 0000000000..ea464168fa --- /dev/null +++ b/flash_attn/cute/blackwell_helpers.py @@ -0,0 +1,606 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Optional, Tuple +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cutlass_dsl import T +from cutlass._mlir.dialects import llvm + +import flash_attn.cute.mma_sm100_desc as sm100_desc + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: bool | cutlass.Boolean = False, +) -> None: + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + tiled_mma.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0) + cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + + +def i64_to_i32x2(i: int) -> Tuple[int, int]: + """Convert a 64-bit integer to a tuple of two 32-bit integers.""" + return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF + + +@cute.jit +def gemm_ptx( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else None + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + if cutlass.const_expr(not is_ts): + smem_desc_a_lo = smem_desc_start_a_lo + ((cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4) + smem_desc_b_lo = smem_desc_start_b_lo + ((cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4) + # with cute.arch.elect_one(): + # cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", 
smem_desc_a_lo, smem_desc_b_lo) + # cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct) + with cute.arch.elect_one(): + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + smem_desc_a_lo.ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + ".reg .b32 idesc;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + tCrA[None, None, k].iterator.toint().ir_value(), + smem_desc_b_lo.ir_value(), + cutlass.Int32(not zero_init or k != 0).ir_value(), + ], + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b64 smem_desc_b;\n\t" + f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@cute.jit +def gemm_ptx_loop( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + if cutlass.const_expr(not is_ts): + offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))] + else: + offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in 
cutlass.range_constexpr(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))] + offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4 + for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 smem_desc_a_lo, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + "mov.b32 tmem_a, $1;\n\t" + "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, 
{{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def gemm_ptx_partial( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA: Optional[cute.Tensor], + sB: cute.Tensor, + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + mbar_ptr: Optional[cutlass.Pointer] = None, + mbar_phase: Optional[cutlass.Int32] = None, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA is not None, "sA must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + sA_layout = sA.layout if sA is not None else tCrA.layout + sB_layout = sB.layout + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + + tCrA_layout = tCrA.layout if cutlass.const_expr(not is_ts) else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout) + offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + else: + smem_desc_start_a_lo = None + smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM" + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(smem_desc_start_a_lo).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + 
cutlass.Int32(not zero_init).ir_value(), + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + "mov.b32 smem_desc_a_lo, $0;\n\t" + "mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + input_args = [ + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + ] + if cutlass.const_expr(mbar_ptr is not None): + assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None" + input_args.append(mbar_ptr.toint().ir_value()) + input_args.append(cutlass.Int32(mbar_phase).ir_value()) + mbar_wait_str = ( + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [$3], $4, 10000000; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + ) + else: + mbar_wait_str = "" + llvm.inline_asm( + None, + # [ + # # acc.iterator.toint().ir_value(), + # cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + # cutlass.Int32(smem_desc_start_b_lo).ir_value(), + # cutlass.Int32(not zero_init).ir_value(), + # ], + input_args, + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + f"mov.b32 tmem_a, $0;\n\t" + f"mov.b32 smem_desc_b_lo, $1;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $2, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + + "".join( + ( + # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" 
+ ) + for k in range(1, cute.size(tCrA.shape[2]) if cutlass.const_expr(mbar_ptr is None) else cute.size(tCrA.shape[2]) // 4 * 3) + ) + + mbar_wait_str + + ("".join( + ( + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + ) + for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2])) + ) if cutlass.const_expr(mbar_ptr is not None) else "") + + "}\n", + # "r,r,r", + "r,r,r" if cutlass.const_expr(mbar_ptr is None) else "r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + +@cute.jit +def gemm_ptx_partial1( + op: cute.nvgpu.tcgen05.mma.MmaOp, + acc_tmem_addr: cutlass.Constexpr[int], + tCrA: cute.Tensor, + tCrB: cute.Tensor, + sA_base_addr_for_desc: cutlass.Int32, + sA_addr_offset_for_desc: cutlass.Constexpr[int], + sA_stage: cutlass.Int32, + sB_base_addr_for_desc: cutlass.Int32, + sB_addr_offset_for_desc: cutlass.Constexpr[int], + sB_stage: cutlass.Int32, + sA_layout: Optional[cute.Layout], + sB_layout: Optional[cute.Layout], + sA_swizzle: Optional[cute.Swizzle], + sB_swizzle: cute.Swizzle, + zero_init: bool | cutlass.Boolean = False, +) -> None: + is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM + if cutlass.const_expr(not is_ts): + assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" + assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" + idesc: int = cutlass.const_expr(sm100_desc.mma_op_to_idesc(op)) + if cutlass.const_expr(not is_ts): + smem_desc_base_a: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.a_dtype.width, sA_layout[0]), + sA_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a) + smem_desc_base_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + smem_desc_a_hi = cutlass.const_expr(smem_desc_a_hi) + else: + smem_desc_base_a = None + smem_desc_base_a_lo, smem_desc_a_hi = None, None + smem_desc_base_b: int = cutlass.const_expr(sm100_desc.make_smem_desc_base( + cute.recast_layout(128, op.b_dtype.width, sB_layout[0]), + sB_swizzle, + sm100_desc.Major.K if cutlass.const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K) else sm100_desc.Major.MN + )) + smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b) + smem_desc_base_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + smem_desc_b_hi = cutlass.const_expr(smem_desc_b_hi) + mask = [cutlass.Int32(0)] * 4 + + if cutlass.const_expr(not is_ts): + offset_a = [(cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4 + for k in range(cute.size(tCrA.shape[2]))] + else: + offset_a = [cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32 + for k in range(cute.size(tCrA.shape[2]))] + offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))] + offset_b = [(cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4 + for k in range(cute.size(tCrB.shape[2]))] + offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))] + + if cutlass.const_expr(not is_ts): + # smem_desc_start_a_lo = cutlass.Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)) + 
smem_desc_start_a_lo = cutlass.const_expr(smem_desc_base_a_lo) + else: + smem_desc_start_a_lo = None + # smem_desc_start_b_lo = cutlass.Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)) + smem_desc_start_b_lo = cutlass.const_expr(smem_desc_base_b_lo) + pred_str = "p" if isinstance(zero_init, cutlass.Boolean) else "0" if zero_init else "1" + if cutlass.const_expr(not is_ts): + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(), + cutlass.Int32(sA_base_addr_for_desc).ir_value(), + cutlass.Int32(sA_stage).ir_value(), + # cutlass.Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(), + cutlass.Int32(sB_base_addr_for_desc).ir_value(), + cutlass.Int32(sB_stage).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value() + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_acc;\n\t" + ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_a, smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t" + # "mov.b32 smem_desc_a_lo, $0;\n\t" + # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t" + f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t" + # "mov.b32 smem_desc_b_lo, $2;\n\t" + f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t" + f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $4, 0;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + else: + llvm.inline_asm( + None, + [ + # acc.iterator.toint().ir_value(), + cutlass.Int32(tCrA[None, None, 0].iterator.toint()).ir_value(), + cutlass.Int32(smem_desc_start_b_lo).ir_value(), + cutlass.Int32(not zero_init).ir_value(), + mask[0].ir_value(), + mask[1].ir_value(), + mask[2].ir_value(), + mask[3].ir_value() + ], + "{\n\t" + ".reg .pred leader_thread;\n\t" + ".reg .pred p;\n\t" + ".reg .b32 idesc;\n\t" + ".reg .b32 tmem_a;\n\t" + ".reg .b32 smem_desc_b_lo;\n\t" + ".reg .b32 smem_desc_b_hi;\n\t" + ".reg .b64 smem_desc_b;\n\t" + "elect.sync _|leader_thread, -1;\n\t" + f"mov.b32 idesc, {hex(idesc)};\n\t" + f"mov.b32 tmem_a, $1;\n\t" + f"mov.b32 smem_desc_b_lo, $2;\n\t" + f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + "setp.ne.b32 p, $3, 0;\n\t" + f"@leader_thread 
tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + + "".join( + ( + f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" + f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" + f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + ) + for k in range(1, cute.size(tCrA.shape[2])) + ) + + "}\n", + "r,r,r,r,r,r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py new file mode 100644 index 0000000000..2914e42e2a --- /dev/null +++ b/flash_attn/cute/block_info.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +from typing import Tuple, Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute + +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +@dataclass(frozen=True) +class BlockInfo: + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + is_causal: cutlass.Constexpr[bool] + is_local: cutlass.Constexpr[bool] = False + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + + @cute.jit + def get_n_block_min_max( + self, seqlen_info: SeqlenInfoQK, m_block: cutlass.Int32 + ) -> Tuple[cutlass.Int32, cutlass.Int32]: + n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.n_block_size) + if cutlass.const_expr( + self.is_causal or (self.is_local and self.window_size_right is not None) + ): + m_idx_max = (m_block + 1) * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = n_idx if cutlass.const_expr(self.is_causal) else n_idx + self.window_size_right + n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.n_block_size)) + n_block_min = 0 + if cutlass.const_expr(self.is_local and self.window_size_left is not None): + m_idx_min = m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + n_block_min = cutlass.max(n_idx_left // self.n_block_size, 0) + return n_block_min, n_block_max + + @cute.jit + def get_n_block_min_causal_local_mask( + self, + seqlen_info: SeqlenInfoQK, + m_block: cutlass.Int32, + n_block_min: cutlass.Int32, + ) -> cutlass.Int32: + """If we have separate iterations with causal or local masking at the start, where do we stop""" + m_idx_min = m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa + n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_right = ( + n_idx + if cutlass.const_expr(not self.is_local or self.window_size_right is None) + else n_idx + self.window_size_right + ) + return cutlass.max(n_block_min, n_idx_right // self.n_block_size) + + @cute.jit + def get_n_block_min_before_local_mask( + self, + seqlen_info: SeqlenInfoQK, + m_block: cutlass.Int32, + n_block_min: cutlass.Int32, + ) -> cutlass.Int32: + """If we have 
separate iterations with local masking at the end, where do we stop the non-masked iterations""" + if cutlass.const_expr(not self.is_local or self.window_size_left is None): + return n_block_min + else: + m_idx_max = (m_block + 1) * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + n_idx_left = n_idx - self.window_size_left + return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.n_block_size)) diff --git a/flash_attn/cute/fast_math.py b/flash_attn/cute/fast_math.py new file mode 100644 index 0000000000..943388fd29 --- /dev/null +++ b/flash_attn/cute/fast_math.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Uint32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res + + +def find_log2(x: Int32) -> Int32: + a: Int32 = Int32(31 - clz(x)) + return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. + + +@dsl_user_op +def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], + "mul.hi.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +class FastDivmod: + def __init__( + self, divisor: Int32, multipler: Uint32, shift_right: Uint32, *, loc=None, ip=None + ): + self.divisor = divisor + self.multiplier = multipler + self.shift_right = shift_right + self._loc = loc + + # called by host + @staticmethod + def create(divisor: Int32, *, loc=None, ip=None) -> "FastDivmod": + """Construct the FastDivmod object, in host code. + This precomputes some values based on the divisor and is computationally expensive. 
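        Illustrative example (added note; numbers worked out by hand, not taken from
        the original docstring): for divisor = 3, find_log2(3) = 2, so p = 33,
        multiplier = ceil(2**33 / 3) = 2863311531 (0xAAAAAAAB) and shift_right = 1.
        div(10) then evaluates umulhi(10, 2863311531) >> 1 = 6 >> 1 = 3, and
        divmod(10) returns (3, 1), so an integer divide is replaced by one
        mul.hi.u32 plus a shift on device.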
+ """ + p = Uint32(31 + find_log2(divisor)) + divisor_u32 = Uint32(divisor) + multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) + shift_right = Uint32(p - 32) + return FastDivmod(divisor, multiplier, shift_right, loc=loc, ip=ip) + + @cute.jit + def div(self, dividend: Int32) -> Int32: + return ( + Int32(umulhi(dividend, self.multiplier) >> self.shift_right) + if self.divisor != 1 + else dividend + ) + + def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: + quotient = self.div(dividend) + remainder = dividend - quotient * self.divisor + return quotient, remainder + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.divisor, self.multiplier, self.shift_right]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.divisor, self.multiplier, self.shift_right], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FastDivmod(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/flash_bwd.py b/flash_attn/cute/flash_bwd.py new file mode 100644 index 0000000000..619e0408cd --- /dev/null +++ b/flash_attn/cute/flash_bwd.py @@ -0,0 +1,1137 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp +# from Cutlass C++ to Cute-DSL. +import math +from types import SimpleNamespace +from typing import Type, Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warp +import cutlass.utils.ampere_helpers as sm80_utils_basic + +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK + + +class FlashAttentionBackwardSm80: + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + m_block_size: int = 64, + n_block_size: int = 128, + num_stages_Q: int = 2, + num_stages_dO: int = 2, + num_threads: int = 256, + is_causal: bool = False, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 1, + AtomLayoutNdKV: int = 8, + AtomLayoutMdQ: int = 1, + V_in_regs: bool = False, + ): + """Initializes the configuration for a flash attention v2 kernel. + + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. 
+ + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + """ + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 32 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.qhead_per_kvhead = qhead_per_kvhead + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.num_threads = num_threads + self.is_causal = is_causal + self.num_stages_Q = num_stages_Q + self.num_stages_dO = num_stages_dO + self.SdP_swapAB = SdP_swapAB + self.dKV_swapAB = dKV_swapAB + self.dQ_swapAB = dQ_swapAB + self.AtomLayoutMSdP = AtomLayoutMSdP + self.AtomLayoutNdKV = AtomLayoutNdKV + self.AtomLayoutMdQ = AtomLayoutMdQ + num_mma_warps = self.num_threads // cute.arch.WARP_SIZE + self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB + self.V_in_regs = V_in_regs + self.share_QV_smem = V_in_regs + + @staticmethod + def can_implement( + dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO, + num_threads, is_causal, + V_in_regs=False + ) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2 + smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2 + smem_usage_K = n_block_size * head_dim * 2 + smem_usage_V = n_block_size * head_dim_v * 2 + smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K + smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + if smem_usage > smem_capacity: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mdO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric], + mdPsum_type: Type[cutlass.Numeric], + mdQaccum_type: Type[cutlass.Numeric], + mdK_type: Type[cutlass.Numeric], + mdV_type: Type[cutlass.Numeric], + ): + if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): + raise TypeError("All tensors must have the same data type") + if cutlass.const_expr(self.qhead_per_kvhead == 1): + if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)): + raise TypeError("mdK and mdV tensors must have the same data type as mQ") + else: + if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)): + raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") + if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if cutlass.const_expr(not mLSE_type in [cutlass.Float32]): + raise TypeError("LSE tensor must be Float32") + if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]): + raise TypeError("dPsum tensor must be Float32") + if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2), + ) + sK_layout_atom = sQ_layout_atom + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1), + ) + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1), + ) + sdO_layout_atom = sV_layout_atom + self.sdO_layout = 
cute.tile_to_shape( + sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2), + ) + # TODO: do we set swizzle to be 3 here explicitly? + sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size) + self.sPdS_layout = cute.tile_to_shape( + sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), + ) + # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, + # it's still a valid smem address. + self.sLSE_layout = cute.make_layout( + (self.m_block_size, self.num_stages_Q), + stride=(1, cute.round_up(self.m_block_size, 64)), + ) + sLSEMma_layout = cute.make_layout( + (self.m_block_size, self.n_block_size, self.num_stages_Q), + stride=(1, 0, cute.round_up(self.m_block_size, 64)), + ) + sLSEMma_layout_transposed = cute.make_layout( + (self.n_block_size, self.m_block_size, self.num_stages_Q), + stride=(0, 1, cute.round_up(self.m_block_size, 64)), + ) + self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + ) + # tQK_layout: thread layout for QK load + tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + tQK_layout = cute.make_ordered_layout( + (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + ) + # Do we need to check if we overshot kBlockM when we load Q? + self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0 + # Do we need to check if we overshot kBlockN when we load K? + self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0 + tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1" + tVdO_layout = cute.make_ordered_layout( + (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0), + ) + # Do we need to check if we overshot kBlockN when we load V? 
+ self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0 + self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0 + + # Value layouts for copies + vQKVdO_layout = cute.make_layout((1, async_copy_elems)) + + # gmem_tiled_copy_QK: tiled copy for QK load + self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout) + self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout) + self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout) + self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout) + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + atom_async_copy_accum, + cute.make_layout(self.num_threads), + cute.make_layout(async_copy_elems_accum), + ) + self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width + ), + cute.make_layout(self.num_threads), + cute.make_layout(1) + ) + if cutlass.const_expr(self.qhead_per_kvhead > 1): + self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum + self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum + + def _get_tiled_mma(self): + num_mma_warps = self.num_threads // 32 + AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) + tiled_mma_sdp = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutSdP, + permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16), + ) + AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) + tiled_mma_dkv = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutdKV, + permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16), + ) + AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + tiled_mma_dq = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutdQ, + permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + ) + return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct, sdO_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + sLSE_struct, sdPsum_struct = [ + cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128] + for layout in (self.sLSE_layout, self.sLSE_layout) + ] + sP_struct, sdS_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128] + for layout in (self.sPdS_layout, self.sPdS_layout) + ] + + @cute.struct + class 
SharedStorageSeparateQV: + sK: sK_struct + sV: sV_struct + sQ: sQ_struct + sdO: sdO_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sP: sP_struct + sdS: sdS_struct + # TODO: the case where there's no sP + + @cute.struct + class SharedStorageSharedQV: + sK: sK_struct + sV: sV_struct + sQ: sQV_struct + sdO: sdO_struct + sLSE: sLSE_struct + sdPsum: sdPsum_struct + sP: sP_struct + sdS: sdS_struct + + return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccum: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + # Get the data type and check if it is fp16 or bf16 + self._check_type(*(t.element_type if t is not None else None + for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV))) + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() + # grid_dim: (n_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mK.shape[1], self.n_block_size), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[0]), + ) + softmax_scale_log2 = softmax_scale * math.log2(math.e) + self.kernel( + mQ, + mK, + mV, + mdO, + mLSE, + mdPsum, + mdQaccum, + mdK, + mdV, + softmax_scale, + softmax_scale_log2, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sdO_layout, + self.sPdS_layout, + self.sLSE_layout, + self.sLSEMma_layout, + self.gmem_tiled_copy_QK, + self.gmem_tiled_copy_VdO, + self.gmem_tiled_copy_dK, + self.gmem_tiled_copy_dV, + self.gmem_tiled_copy_LSE, + self.gmem_tiled_copy_dQaccum, + tiled_mma_sdp, + tiled_mma_dkv, + tiled_mma_dq, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mdO: cute.Tensor, + mLSE: cute.Tensor, + mdPsum: cute.Tensor, + mdQaccu: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + softmax_scale: cutlass.Float32, + softmax_scale_log2: cutlass.Float32, + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sdO_layout: cute.ComposedLayout, + sPdS_layout: cute.ComposedLayout, + sLSE_layout: cute.Layout, + sLSEMma_layout: cute.Layout, + gmem_tiled_copy_QK: cute.TiledCopy, + gmem_tiled_copy_VdO: cute.TiledCopy, + gmem_tiled_copy_dK: cute.TiledCopy, + gmem_tiled_copy_dV: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + gmem_tiled_copy_dQaccum: cute.TiledCopy, + tiled_mma_sdp: cute.TiledMma, + tiled_mma_dkv: cute.TiledMma, + tiled_mma_dq: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + n_block, head_idx, batch_idx = cute.arch.block_idx() + + m_block_max = cute.ceil_div(mQ.shape[1], self.m_block_size) + m_block_min = 0 + if cutlass.const_expr(self.is_causal): + m_block_min = max( + (n_block * self.n_block_size + mQ.shape[1] - mK.shape[1]) // self.m_block_size, + m_block_min, + ) + # TODO: return early if m_block_max == 0 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
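For orientation (an editorial aside, not part of the patch): each CTA of this backward kernel owns one (n_block, head, batch) tile of K/V and loops over the query m-blocks that can interact with it, and the m_block_min computed above is where that loop starts under causal masking. A minimal sketch of that arithmetic in plain Python, with illustrative names rather than kernel symbols:

    def first_m_block_causal(n_block, n_block_size, m_block_size, seqlen_q, seqlen_k):
        # First query row that can attend to any key in this n-block under causal masking,
        # accounting for the bottom-right alignment when seqlen_q != seqlen_k.
        first_q = n_block * n_block_size + seqlen_q - seqlen_k
        return max(first_q // m_block_size, 0)

    # e.g. kBlockN=128, kBlockM=64, equal seqlens: keys [256, 384) are first visible to
    # query row 256, so the loop over m-blocks starts at block 4.
    assert first_m_block_causal(2, 128, 64, 1024, 1024) == 4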
+ # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.m_block_size, self.head_dim_padded) + blkK_shape = (self.n_block_size, self.head_dim_padded) + blkV_shape = (self.n_block_size, self.head_dim_v_padded) + blkdO_shape = (self.m_block_size, self.head_dim_v_padded) + # (m_block_size, head_dim, m_block) + gQ = cute.local_tile(mQ[batch_idx, None, head_idx, None], blkQ_shape, (None, 0)) + # (n_block_size, head_dim) + head_idx_kv = head_idx // self.qhead_per_kvhead + gK = cute.local_tile(mK[batch_idx, None, head_idx_kv, None], blkK_shape, (n_block, 0)) + # (n_block_size, head_dim_v) + gV = cute.local_tile(mV[batch_idx, None, head_idx_kv, None], blkV_shape, (n_block, 0)) + # (m_block_size, head_dim_v, m_block) + gdO = cute.local_tile(mdO[batch_idx, None, head_idx, None], blkdO_shape, (None, 0)) + gLSE = cute.local_tile(mLSE[batch_idx, head_idx, None], (self.m_block_size,), (None,)) + gdPsum = cute.local_tile(mdPsum[batch_idx, head_idx, None], (self.m_block_size,), (None,)) + gdQaccum = cute.local_tile(mdQaccu[batch_idx, head_idx, None], (self.m_block_size * self.head_dim_padded,), (None,)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if cutlass.const_expr(not self.share_QV_smem): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + sdO = storage.sdO.get_tensor(sdO_layout) + sP = storage.sP.get_tensor(sPdS_layout) + sdS = storage.sdS.get_tensor(sPdS_layout) + sLSE = storage.sLSE.get_tensor(sLSE_layout) + sdPsum = storage.sdPsum.get_tensor(sLSE_layout) + sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout) + sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout) + + # Transpose view of tensors for tiled mma + sQt, sdOt, sKt, sPt, sdSt = [utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)] + + gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx) + gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx) + gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tQgQ = gmem_thr_copy_QK.partition_S(gQ) + tQsQ = gmem_thr_copy_QK.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K) + tKgK = gmem_thr_copy_QK.partition_S(gK) + tKsK = gmem_thr_copy_QK.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K) + tVgV = gmem_thr_copy_VdO.partition_S(gV) + tVsV = gmem_thr_copy_VdO.partition_D(sV) + # (CPY_Atom, CPY_M, CPY_K, m_block) + tdOgdO = gmem_thr_copy_VdO.partition_S(gdO) + tdOsdO = gmem_thr_copy_VdO.partition_D(sdO) + tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE) + tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE) + tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum) + tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum) + tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_sdp = tiled_mma_sdp.get_slice(tidx) + thr_mma_dkv = tiled_mma_dkv.get_slice(tidx) + thr_mma_dq = tiled_mma_dq.get_slice(tidx) + acc_shape_dK = 
thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded)) + acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded)) + acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32) + acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32) + acc_dK.fill(0.0) + acc_dV.fill(0.0) + + tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB) + tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB) + tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB) + tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB) + tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB) + tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB) + + LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None) + tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice] + tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice] + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + ) + smem_copy_atom_transposed = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + ) + smem_thr_copy_QdO = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + smem_thr_copy_KV = utils.make_tiled_copy_B( + smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB + ).get_slice(tidx) + # TODO: should this be smem_copy_atom_transposed? + smem_thr_copy_PdSt = utils.make_tiled_copy_A( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_QdOt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB + ).get_slice(tidx) + smem_thr_copy_dS = utils.make_tiled_copy_A( + smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + smem_thr_copy_Kt = utils.make_tiled_copy_B( + smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB + ).get_slice(tidx) + # TODO: what's the number of bits? 
What if SdP_swapAB + r2s_thr_copy_PdS = cute.make_tiled_copy_C( + cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ), + tiled_mma_sdp, + ).get_slice(tidx) + + tSsQ = smem_thr_copy_QdO.partition_S(sQ) + tdPsdO = smem_thr_copy_QdO.partition_S(sdO) + tSsK = smem_thr_copy_KV.partition_S(sK) + tdPsV = smem_thr_copy_KV.partition_S(sV) + tdVsPt = smem_thr_copy_PdSt.partition_S(sPt) + tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt) + tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt) + tdKsQt = smem_thr_copy_QdOt.partition_S(sQt) + tdQsdS = smem_thr_copy_dS.partition_S(sdS) + tdQsKt = smem_thr_copy_Kt.partition_S(sKt) + tPsP = r2s_thr_copy_PdS.partition_D(sP) + tdSsdS = r2s_thr_copy_PdS.partition_D(sdS) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQcQ = gmem_thr_copy_QK.partition_S(cQ) + t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdOcdO = tQcQ + t0dOcdO = t0QcQ + else: + cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tdOcdO = gmem_thr_copy_VdO.partition_S(cdO) + t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO) + cLSE = cute.make_identity_tensor((self.m_block_size,)) + tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE) + + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. 
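Only the head-dimension (k) predicate is materialized here; the row (m) bound is handled with an `if`, and in the load helpers below it is rewritten against thread 0's coordinates (t0QcQ etc.), which are compile-time constants, by shifting this thread's offset onto the limit instead. A quick sanity check of that rewrite in plain Python, with made-up numbers, purely illustrative:

    # tQcQ[m] = t0QcQ[m] + thread_offset, so
    #   tQcQ[m] < limit        <=>        t0QcQ[m] < limit - thread_offset
    thread_offset = 3                # this thread's row offset (tQcQ[0][0] in the kernel)
    limit = 37                       # e.g. seqlen - block * kBlockM
    t0_rows = [0, 8, 16, 24]         # thread 0's row coordinates, known at compile time
    t_rows = [r + thread_offset for r in t0_rows]
    assert all((t < limit) == (t0 < limit - thread_offset)
               for t, t0 in zip(t_rows, t0_rows))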
+ tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[3]) + if cutlass.const_expr(self.same_hdim_kv): + tdOpdO = tQpQ + else: + tdOpdO = utils.predicate_k(tdOcdO, limit=mdO.shape[3]) + + # group parameters for compute_one_m_block + mma_params = SimpleNamespace( + thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq, + tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV, + tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ, + tdQrdS=tdQrdS, tdQrK=tdQrK, + acc_dK=acc_dK, acc_dV=acc_dV, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_QdO=smem_thr_copy_QdO, + smem_thr_copy_KV=smem_thr_copy_KV, + smem_thr_copy_PdSt=smem_thr_copy_PdSt, + smem_thr_copy_QdOt=smem_thr_copy_QdOt, + smem_thr_copy_dS=smem_thr_copy_dS, + smem_thr_copy_Kt=smem_thr_copy_Kt, + r2s_thr_copy_PdS=r2s_thr_copy_PdS, + tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV, + tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma, + tPsP=tPsP, tdSsdS=tdSsdS, + tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt, + tdQsdS=tdQsdS, tdQsKt=tdQsKt, + ) + gmem_copy_params = SimpleNamespace( + gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum + ) + seqlen = SeqlenInfoQK(batch_idx, mQ.shape[1], mK.shape[1]) + load_Q_LSE = partial( + self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE, + tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ, + tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + load_dO_dPsum = partial( + self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE, + tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO, + tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q + ) + compute_one_m_block = partial( + self.compute_one_m_block, mma_params=mma_params, + smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, + load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, + m_block_max=m_block_max, + softmax_scale_log2=softmax_scale_log2, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, + headdim=mV.shape[3]) + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_commit_group() + self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, + headdim=mK.shape[3]) + cute.arch.cp_async_commit_group() + + if cutlass.const_expr(self.V_in_regs): + cute.arch.cp_async_wait_group(1) + cute.arch.barrier() + tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV) + cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view) + # Sync to avoid loading Q to smem_q, which overlaps with smem_v + cute.arch.barrier() + + m_block = m_block_min + assert self.num_stages_Q >= self.num_stages_dO + for stage in cutlass.range_constexpr(self.num_stages_Q): + if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1): + if stage == 0 or m_block + stage < m_block_max: + load_Q_LSE(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + if cutlass.const_expr(stage < self.num_stages_dO): + if stage == 0 or m_block + stage < m_block_max: + load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage) + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. 
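For reference, each iteration of the loop below (compute_one_m_block) realizes the standard attention-backward tile update. A plain PyTorch sketch of the math for one (m_block, n_block) pair, with masking omitted and illustrative names that are not the kernel's symbols:

    import torch

    def bwd_tile_reference(q, k, v, do, lse, dpsum, scale):
        # q, do: (kBlockM, d); k, v: (kBlockN, d); lse, dpsum: (kBlockM,)
        # lse is logsumexp of the scaled scores over the full key length,
        # dpsum is rowsum(dO * O); both are precomputed by the preprocess kernel.
        s = q @ k.T
        p = torch.exp(s * scale - lse[:, None])  # matches exp2f(acc_S * softmax_scale_log2 - LSE)
        dv = p.T @ do                            # accumulated into acc_dV
        dp = do @ v.T
        ds = p * (dp - dpsum[:, None])           # acc_dP after the pointwise update
        dk = ds.T @ q                            # accumulated into acc_dK; scaled by softmax_scale in the epilogue
        dq = ds @ k                              # atomically added to dQaccum; scaled in the postprocess kernel
        return dq, dk, dv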
+ mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp, + mask_seqlen=True, mask_causal=self.is_causal + ) + smem_pipe_read_q = cutlass.Int32(0) + smem_pipe_read_do = cutlass.Int32(0) + smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1) + smem_pipe_write_do = cutlass.Int32(0) + for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1): + compute_one_m_block( + m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do, + mask_fn=mask_fn, + ) + smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q) + smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO) + smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q) + smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # If GQA, we scale dK in the postprocessing kernel instead + if cutlass.const_expr(self.qhead_per_kvhead == 1): + acc_dK.store(acc_dK.load() * softmax_scale) + # reuse sK and sV data iterator + sdK = cute.make_tensor(sK.iterator, sK_layout) + sdV = cute.make_tensor(sV.iterator, sV_layout) + self.epilogue( + acc_dK, acc_dV, mdK, mdV, sdK, sdV, + gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv, + tidx, n_block, head_idx, batch_idx + ) + + @cute.jit + def compute_one_m_block( + self, + m_block: cutlass.Int32, + smem_pipe_read_q: cutlass.Int32, + smem_pipe_read_do: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + smem_pipe_write_do: cutlass.Int32, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + gmem_copy_params: SimpleNamespace, + load_Q_LSE: Callable, + load_dO_dPsum: Callable, + m_block_max: cutlass.Int32, + softmax_scale_log2: cutlass.Float32, + mask_fn: Optional[Callable] = None, + ): + def load_Q_next(): + m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1) + if m_block_next < m_block_max: + load_Q_LSE(m_block_next, smem_pipe_write_q) + cute.arch.cp_async_commit_group() + + def load_dO_next(): + if m_block + self.num_stages_dO < m_block_max: + load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do) + cute.arch.cp_async_commit_group() + + # MMA S + acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C( + (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size) + ) + acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32) + acc_S.fill(0.0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0) + cute.arch.barrier() + sm80_utils.gemm( + mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK, + smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + smem_copy_params.tSsK, + smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, + swap_AB=self.SdP_swapAB, + ) + tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) + cute.autovec_copy( + smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE + ) + if cutlass.const_expr(mask_fn is not None): + mask_fn(acc_S, m_block=m_block) + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + bidx = 0 + # if 
cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) + assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE) + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): + acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r])) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) + + # MMA dP + acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32) + acc_dP.fill(0.0) + cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0) + cute.arch.barrier() + sm80_utils.gemm( + mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV, + smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], + smem_copy_params.tdPsV, + smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, + hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None, + swap_AB=self.SdP_swapAB, + ) + tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0]) + cute.autovec_copy( + smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum + ) + acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) + assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) + for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): + acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + if cutlass.const_expr(not self.Mma_dKV_is_RS): + tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP) # ((Atom,AtomNum), MMA_N, MMA_N) + cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP) + rdS = cute.make_fragment_like(acc_dP, self.dtype) + rdS.store(acc_dP.load().to(self.dtype)) + if cutlass.const_expr(not self.Mma_dKV_is_RS): + cute.arch.barrier() # Make sure P is written + # For hdim 64, It's faster to write to smem_dS first before the dV gemm + if cutlass.const_expr(not self.Mma_dKV_is_RS): + tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS) + cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS) + if cutlass.const_expr(self.Mma_dKV_is_RS): + tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) + else: + tdVrP = mma_params.tdVrP + + # MMA dK + sm80_utils.gemm( + mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO, + smem_copy_params.tdVsPt, + smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], + smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, + A_in_regs=self.Mma_dKV_is_RS, + swap_AB=self.dKV_swapAB, + ) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV) + cute.arch.barrier() # Make sure dS is written + + # MMA dQ + def dQ_mma(hook_fn): + acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C( + (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, 
self.m_block_size) + ) + acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32) + acc_dQ.fill(0.0) + sm80_utils.gemm( + mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK, + smem_copy_params.tdQsdS, smem_copy_params.tdQsKt, + smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt, + swap_AB=self.dQ_swapAB, + hook_fn=hook_fn + ) + # ((1, 1), num_elements) + acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ) + tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block] + assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(acc_dQ) + for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i)) + # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1]) + # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ) + + # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration + if cutlass.const_expr(self.num_stages_Q > 1): + dQ_mma(load_dO_next) + + # MMA dK + if cutlass.const_expr(self.Mma_dKV_is_RS): + tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout)) + else: + tdKrdS = mma_params.tdKrdS + sm80_utils.gemm( + mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ, + smem_copy_params.tdKsdSt, + smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt, + A_in_regs=self.Mma_dKV_is_RS, + swap_AB=self.dKV_swapAB, + hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None, + ) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK) + if cutlass.const_expr(self.num_stages_Q == 1): + cute.arch.barrier() + dQ_mma(load_Q_next) + + @cute.jit + def epilogue( + self, + acc_dK: cute.Tensor, + acc_dV: cute.Tensor, + mdK: cute.Tensor, + mdV: cute.Tensor, + sdK: cute.Tensor, + sdV: cute.Tensor, + gmem_tiled_copy_dK: cute.TiledCopy, + gmem_tiled_copy_dV: cute.TiledCopy, + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + n_block: cutlass.Int32, + num_head: cutlass.Int32, + batch_size: cutlass.Int32, + ): + rdV = cute.make_fragment_like(acc_dV, self.dtype) + rdV.store(acc_dV.load().to(self.dtype)) + rdK = cute.make_fragment_like(acc_dK, self.dtype) + rdK.store(acc_dK.load().to(self.dtype)) + gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx) + gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx) + + if cutlass.const_expr(self.qhead_per_kvhead == 1): + # Make sure all threads have finished reading K and V, otherwise we get racy dQ + # because smem_q could be changed. 
+ cute.arch.barrier() + # smem copy atom for dKV + smem_copy_atom_dKV = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width + ) + smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx) + taccdVrdV = smem_thr_copy_dKV.retile(rdV) + taccdKrdK = smem_thr_copy_dKV.retile(rdK) + taccdVsdV = smem_thr_copy_dKV.partition_D(sdV) + taccdKsdK = smem_thr_copy_dKV.partition_D(sdK) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV) + cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK) + + blkdK_shape = (self.n_block_size, self.head_dim_padded) + blkdV_shape = (self.n_block_size, self.head_dim_v_padded) + gdK = cute.local_tile(mdK[batch_size, None, num_head, None], blkdK_shape, (n_block, 0)) + gdV = cute.local_tile(mdV[batch_size, None, num_head, None], blkdV_shape, (n_block, 0)) + tdKsdK = gmem_thr_copy_dK.partition_S(sdK) + tdKgdK = gmem_thr_copy_dK.partition_D(gdK) + tdVsdV = gmem_thr_copy_dV.partition_S(sdV) + tdVgdV = gmem_thr_copy_dV.partition_D(gdV) + tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype) + tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype) + # sync before all smem stores are done. + cute.arch.barrier() + # load acc dK and dV from smem to rmem for wider vectorization + # Need to check OOB when reading from smem if kBlockN isn't evenly tiled + # TODO + cute.autovec_copy(tdKsdK, tdKrdK) + cute.autovec_copy(tdVsdV, tdVrdV) + + cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tdKcdK = gmem_thr_copy_dK.partition_S(cdK) + t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK) + if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded): + tdVcdV = tdKcdK + t0dVcdV = t0dKcdK + else: + cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + tdVcdV = gmem_thr_copy_dV.partition_S(cdV) + t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV) + tdKpdK = utils.predicate_k(tdKcdK, limit=mdK.shape[3]) + if cutlass.const_expr(self.same_hdim_kv): + tdVpdV = tdKpdK + else: + tdVpdV = utils.predicate_k(tdVcdV, limit=mdV.shape[3]) + # copy acc dK and acc_dV from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])): + if t0dKcdK[0, rest_m, 0][0] < mdK.shape[1] - n_block * self.n_block_size - tdKcdK[0][0]: + cute.copy( + gmem_tiled_copy_dK, + tdKrdK[None, rest_m, None], + tdKgdK[None, rest_m, None], + pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])): + if t0dVcdV[0, rest_m, 0][0] < mdV.shape[1] - n_block * self.n_block_size - tdVcdV[0][0]: + cute.copy( + gmem_tiled_copy_dV, + tdVrdV[None, rest_m, None], + tdVgdV[None, rest_m, None], + pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None, + ) + + else: # qhead_per_kvhead > 1, do atomic add + # For Sm90, we need to sync to avoid racy writes to smem_q + # For Sm80, we don't need to sync since we're not touching smem + num_head_kv = num_head // self.qhead_per_kvhead + gdV = cute.local_tile(mdV[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_v_padded,), (n_block,)) + gdK = cute.local_tile(mdK[batch_size, num_head_kv, None], (self.n_block_size * self.head_dim_padded,), (n_block,)) + tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV) + tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK) + acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV) + acc_dK_atomic = 
gmem_thr_copy_dK.retile(acc_dK) + assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum) + assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum) + for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i)) + for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True): + utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i)) + + @cute.jit + def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr): + return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0 + + @cute.jit + def load_K( + self, + gmem_thr_copy: cute.TiledCopy, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + block: cutlass.Int32, + seqlen: cutlass.Int32, + headdim: cutlass.Int32, + ): + cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tKcK = gmem_thr_copy.partition_S(cK) + t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK) + tKpK = utils.predicate_k(tKcK, limit=headdim) + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size: + # Instead of using tKcK, we using t0KcK and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. + predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0] + predicate = cute.make_fragment_like(tKpK[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n + cute.copy( + gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate, + ) + # We need to clear the sK smem tiles since we'll use sKt for mma_dq + + @cute.jit + def load_V( + self, + gmem_thr_copy: cute.TiledCopy, + tVgV: cute.Tensor, + tVsV: cute.Tensor, + block: cutlass.Int32, + seqlen: cutlass.Int32, + headdim: cutlass.Int32, + ): + cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + tVcV = gmem_thr_copy.partition_S(cV) + t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV) + tVpV = utils.predicate_k(tVcV, limit=headdim) + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: + # Instead of using tVcV, we using t0VcV and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time. 
+ predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0] + predicate = cute.make_fragment_like(tVpV[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n + cute.copy( + gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate, + ) + + @cute.jit + def load_Q_LSE( + self, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, + tQcQ: cute.Tensor, + t0QcQ: cute.Tensor, + tQpQ: cute.Tensor, + tLSEgLSE: cute.Tensor, + tLSEsLSE: cute.Tensor, + tLSEcLSE: cute.Tensor, + block: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + seqlen: cutlass.Int32, + ): + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size: + # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. + predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0] + predicate = cute.make_fragment_like(tQpQ[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m + cute.copy( + gmem_tiled_copy_Q, + tQgQ[None, m, None, block], + tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q) > 1 else 0], + pred=predicate, + ) + # We need to clear the sQ smem tiles since we'll use sQt for mma_dK + # We made sure LSE length is padded so we read `kBlockM` elements so that all + # elements in sLSE are filled. Without this we might have uninitialized sLSE values. + for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])): + if tLSEcLSE[0, m][0] < self.m_block_size: + cute.copy( + gmem_tiled_copy_LSE, + tLSEgLSE[None, m, block], + tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], + ) + + @cute.jit + def load_dO_dPsum( + self, + gmem_tiled_copy_dO: cute.TiledCopy, + gmem_tiled_copy_dPsum: cute.TiledCopy, + tdOgdO: cute.Tensor, + tdOsdO: cute.Tensor, + tdOcdO: cute.Tensor, + t0dOcdO: cute.Tensor, + tdOpdO: cute.Tensor, + tdPsumgdPsum: cute.Tensor, + tdPsumsdPsum: cute.Tensor, + tdPsumcdPsum: cute.Tensor, + block: cutlass.Int32, + smem_pipe_write_q: cutlass.Int32, + seqlen: cutlass.Int32, + ): + for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])): + # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked + if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size: + # Instead of using tdOcdO, we using t0dOcdO and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time. 
+ predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0] + predicate = cute.make_fragment_like(tdOpdO[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m + cute.copy( + gmem_tiled_copy_dO, + tdOgdO[None, m, None, block], + tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], + pred=predicate, + ) + # We need to clear the sQ smem tiles since we'll use sQt for mma_dK + # We made sure LSE length is padded so we read `kBlockM` elements so that all + # elements in sLSE are filled. Without this we might have uninitialized sLSE values. + for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])): + if tdPsumcdPsum[0, m][0] < self.m_block_size: + cute.copy( + gmem_tiled_copy_dPsum, + tdPsumgdPsum[None, m, block], + tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0], + ) diff --git a/flash_attn/cute/flash_bwd_postprocess.py b/flash_attn/cute/flash_bwd_postprocess.py new file mode 100644 index 0000000000..6a408906d5 --- /dev/null +++ b/flash_attn/cute/flash_bwd_postprocess.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +from typing import Type + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warp + +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import utils + + +class FlashAttentionBackwardPostprocess: + def __init__( + self, + dtype: Type[cutlass.Numeric], + # tiled_mma: cute.TiledMma, + head_dim: int, + m_block_size: int = 128, + num_threads: int = 256, + AtomLayoutMdQ: int = 1, + dQ_swapAB: bool = False, + ): + """Initializes the configuration for a flash attention v2 kernel. + + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + """ + self.dtype = dtype + self.m_block_size = m_block_size + # padding head_dim to a multiple of 32 as k_block_size + hdim_multiple_of = 32 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.head_dim_padded + # self.tiled_mma = tiled_mma + self.num_threads = num_threads + self.AtomLayoutMdQ = AtomLayoutMdQ + self.dQ_swapAB = dQ_swapAB + + @staticmethod + def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + return True + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_async_copy_accum = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + # We don't do bound checking for the gmem -> smem load so we just assert here. + assert ( + self.m_block_size * self.head_dim_padded // async_copy_elems_accum + ) % self.tiled_mma.size == 0 + self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + atom_async_copy_accum, + cute.make_layout(self.tiled_mma.size), + cute.make_layout(async_copy_elems_accum), + ) + atom_universal_copy_accum = cute.make_copy_atom( + # multiply by 4 for Sm90 + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=cutlass.Float32.width, + ) + self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + atom_universal_copy_accum, + cute.make_layout(self.tiled_mma.size), + cute.make_layout(1), # 4 for Sm90 + ) + + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_universal_copy: universal copy atom for dQ store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tdQ_layout: thread layout for dQ store + assert self.head_dim_padded % async_copy_elems == 0 + gmem_threads_per_row = math.gcd( + self.head_dim_padded // async_copy_elems, self.tiled_mma.size + ) + assert self.tiled_mma.size % gmem_threads_per_row == 0 + tdQ_layout = cute.make_ordered_layout( + (self.tiled_mma.size // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + # Value layouts for copies + vdQ_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_dQ = cute.make_tiled_copy_tv( + atom_universal_copy, tdQ_layout, vdQ_layout + ) + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: dQaccum / dQ + # /////////////////////////////////////////////////////////////////////////////// + self.sdQaccum_layout = cute.make_layout(self.m_block_size * self.head_dim_padded) + # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, + # then setting kBlockKSmem to 32 will cause "Static shape_div failure". + # We want to treat it as 64 x 48, so kBlockKSmem should be 16. 
+ mma_shape_n = self.tiled_mma.get_tile_size(1) + sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) + self.sdQ_layout = cute.tile_to_shape( + sdQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1) + ) + + @cute.jit + def __call__( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + stream: cuda.CUstream, + ): + # Get the data type and check if it is fp16 or bf16 + if cutlass.const_expr(not mdQ.element_type in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + + num_mma_warps = self.num_threads // 32 + AtomLayoutdQ = ( + (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) + if cutlass.const_expr(not self.dQ_swapAB) + else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) + ) + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), + AtomLayoutdQ, + permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), + ) + self.tiled_mma = tiled_mma + + self._setup_attributes() + + smem_size = max( + cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), + cute.size_in_bytes(self.dtype, self.sdQ_layout), + ) + + # grid_dim: (m_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mdQ.shape[1], self.m_block_size), + cute.size(mdQ.shape[2]), + cute.size(mdQ.shape[0]), + ) + self.kernel( + mdQaccum, + mdQ, + scale, + tiled_mma, + self.dQ_swapAB, + self.sdQaccum_layout, + self.sdQ_layout, + self.g2s_tiled_copy_dQaccum, + self.s2r_tiled_copy_dQaccum, + self.gmem_tiled_copy_dQ, + ).launch( + grid=grid_dim, + block=[tiled_mma.size, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdQaccum: cute.Tensor, + mdQ: cute.Tensor, + scale: cutlass.Float32, + tiled_mma: cute.TiledMma, + dQ_swapAB: cutlass.Constexpr, + sdQaccum_layout: cute.Layout, + sdQ_layout: cute.ComposedLayout, + g2s_tiled_copy_dQaccum: cute.TiledCopy, + s2r_tiled_copy_dQaccum: cute.TiledCopy, + gmem_tiled_copy_dQ: cute.TiledCopy, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, num_head, batch_size = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
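Despite the staging through shared memory and registers below (which exists to get coalesced, vectorized accesses), the net effect of this postprocess kernel is just a scale-and-cast of the fp32 accumulator. A one-line reference in PyTorch, with an illustrative name and block shape (an assumption for the sketch, not the kernel's actual layout):

    import torch

    def dq_postprocess_reference(dq_accum_block: torch.Tensor, scale: float,
                                 out_dtype=torch.bfloat16):
        # dq_accum_block: (kBlockM, head_dim) fp32 partial sums written by the
        # backward kernel's atomic adds into dQaccum.
        return (dq_accum_block * scale).to(out_dtype)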
+ # /////////////////////////////////////////////////////////////////////////////// + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) + blkdQ_shape = (self.m_block_size, self.head_dim_padded) + gdQ = cute.local_tile(mdQ[batch_size, None, num_head, None], blkdQ_shape, (m_block, 0)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) + sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) + + seqlen_q = mdQ.shape[1] + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + + # Step 1: load dQaccum from gmem to smem + g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) + tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) + tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum) + # print(tdQgdQaccum) + # print(tdQsdQaccum) + cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # Step 2: load dQ from smem to rmem + s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) + tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) + # print(s2r_tiled_copy_dQaccum) + # print(sdQaccum) + # thr_mma = tiled_mma.get_slice(tidx) + # print(tiled_mma) + acc_shape = tiled_mma.partition_shape_C( + (self.m_block_size, self.head_dim_padded) + if cutlass.const_expr(not dQ_swapAB) + else (self.head_dim_padded, self.m_block_size) + ) + acc = cute.make_fragment(acc_shape, cutlass.Float32) + assert cute.size(acc) == cute.size(tdQsdQaccum) + tdQrdQaccum = s2r_thr_copy_dQaccum.retile(acc) + # Somehow even after retiling the layouts of tdQsdQaccum and tdQrdQaccum are different. 
+ # So we have to do a for loop to copy + # cute.copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum) + # print(acc) + # print(tdQsdQaccum) # ((1, 1), 64) + # print(tdQrdQaccum) # ((1, 4), 4, 4) + for i in cutlass.range(cute.size(tdQsdQaccum), unroll_full=True): + tdQrdQaccum[i] = tdQsdQaccum[i] + # Convert tdQrdQaccum from fp32 to fp16/bf16 + rdQ = cute.make_fragment_like(acc, self.dtype) + rdQ.store((acc.load() * scale).to(self.dtype)) + + # Step 3: Copy dQ from register to smem + cute.arch.barrier() # make sure all threads have finished loading dQaccum + smem_copy_atom_dQ = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=cutlass.Float32.width + ) + smem_thr_copy_dQ = cute.make_tiled_copy_C(smem_copy_atom_dQ, tiled_mma).get_slice(tidx) + taccdQrdQ = smem_thr_copy_dQ.retile(rdQ) + taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ) + cute.copy(smem_copy_atom_dQ, taccdQrdQ, taccdQsdQ) + # print(taccdQrdQ) + # print(taccdQsdQ) + + # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem + gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) + tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) + tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) + tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) + cute.arch.barrier() # make sure all smem stores are done + # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled + cute.autovec_copy(tdQsdQ, tdQrdQ) + + # Step 5: Copy dQ from register to gmem + cdQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) + tdQpdQ = utils.predicate_k(tdQcdQ, limit=mdQ.shape[3]) + for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): + if tdQcdQ[0, rest_m, 0][0] < mdQ.shape[1] - m_block * self.m_block_size: + cute.copy( + gmem_tiled_copy_dQ, + tdQrdQ[None, rest_m, None], + tdQgdQ[None, rest_m, None], + pred=tdQpdQ[None, rest_m, None], + ) diff --git a/flash_attn/cute/flash_bwd_preprocess.py b/flash_attn/cute/flash_bwd_preprocess.py new file mode 100644 index 0000000000..a5da7b7009 --- /dev/null +++ b/flash_attn/cute/flash_bwd_preprocess.py @@ -0,0 +1,288 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +import operator +from typing import Type, Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute + +from flash_attn.cute import utils + + +class FlashAttentionBackwardPreprocess: + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 128, + num_threads: int = 128, + ): + """ + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. 
+ + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param num_threads: number of threads + :type num_threads: int + """ + self.dtype = dtype + self.m_block_size = m_block_size + # padding head_dim to a multiple of 32 as k_block_size + hdim_multiple_of = 32 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + self.check_hdim_oob = head_dim != self.head_dim_padded + self.num_threads = num_threads + + @staticmethod + def can_implement(dtype, head_dim, m_block_size, num_threads) -> bool: + """Check if the kernel can be implemented with the given parameters. + + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param num_threads: number of threads + :type num_threads: int + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if num_threads < m_block_size: # For multiplying lse with log2 + return False + return True + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + # We want kBlockKGmem to be a power of 2 so that when we do the summing, + # it's just between threads in the same warp + gmem_k_block_size = ( + 128 + if self.head_dim_padded % 128 == 0 + else ( + 64 + if self.head_dim_padded % 64 == 0 + else (32 if self.head_dim_padded % 32 == 0 else 16) + ) + ) + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_universal_copy: universal copy atom for O & dO load + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tOdO_layout: thread layout for O & dO load + self.gmem_threads_per_row = gmem_k_block_size // async_copy_elems + assert self.num_threads % self.gmem_threads_per_row == 0 + tOdO_layout = cute.make_ordered_layout( + (self.num_threads // self.gmem_threads_per_row, self.gmem_threads_per_row), + order=(1, 0), + ) + # Value layouts for copies + vOdO_layout = cute.make_layout((1, async_copy_elems)) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOdO_layout, vOdO_layout + ) + self.gmem_tiled_copy_dO = cute.make_tiled_copy_tv( + atom_universal_copy, tOdO_layout, vOdO_layout + ) + + async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width + atom_universal_copy_accum = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=universal_copy_bits, + ) + assert ( + self.m_block_size * self.head_dim_padded // async_copy_elems_accum + ) % self.num_threads == 0 + self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( + atom_universal_copy_accum, + cute.make_layout(self.num_threads), + cute.make_layout(async_copy_elems_accum), + ) + + @cute.jit + def __call__( + self, + mO: cute.Tensor, + mdO: cute.Tensor, + mdPsum: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSElog2: Optional[cute.Tensor], + mdQaccum: Optional[cute.Tensor], + stream: cuda.CUstream, + ): + # Get the data type and check if it is fp16 or bf16 + if cutlass.const_expr(not (mO.element_type == mdO.element_type)): + raise 
TypeError("All tensors must have the same data type") + if cutlass.const_expr(not mO.element_type in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if cutlass.const_expr(not mdPsum.element_type in [cutlass.Float32]): + raise TypeError("dPsum tensor must be Float32") + if cutlass.const_expr(mdQaccum is not None): + if cutlass.const_expr(not mdQaccum.element_type in [cutlass.Float32]): + raise TypeError("dQaccum tensor must be Float32") + if cutlass.const_expr(mLSE is not None): + assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" + if cutlass.const_expr(not mLSE.element_type in [cutlass.Float32]): + raise TypeError("LSE tensor must be Float32") + if cutlass.const_expr(not mLSElog2.element_type in [cutlass.Float32]): + raise TypeError("LSElog2 tensor must be Float32") + + self._setup_attributes() + + # grid_dim: (m_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mO.shape[1], self.m_block_size), + cute.size(mO.shape[2]), + cute.size(mO.shape[0]), + ) + self.kernel( + mO, + mdO, + mdPsum, + mLSE, + mLSElog2, + mdQaccum, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_dO, + self.gmem_tiled_copy_dQaccum, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO: cute.Tensor, + mdO: cute.Tensor, + mdPsum: cute.Tensor, + mLSE: Optional[cute.Tensor], + mLSElog2: Optional[cute.Tensor], + mdQaccum: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_dO: cute.TiledCopy, + gmem_tiled_copy_dQaccum: cute.TiledCopy, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, num_head, batch_size = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
+ # /////////////////////////////////////////////////////////////////////////////// + blkOdO_shape = (self.m_block_size, self.head_dim_padded) + # (m_block_size, head_dim) + gO = cute.local_tile(mO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + gdO = cute.local_tile(mdO[batch_size, None, num_head, None], blkOdO_shape, (m_block, 0)) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + gmem_thr_copy_dO = gmem_tiled_copy_dO.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tOgO = gmem_thr_copy_O.partition_S(gO) + tOgdO = gmem_thr_copy_dO.partition_S(gdO) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cOdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy_O.partition_S(cOdO) + t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cOdO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[3]) + tOcdO = gmem_thr_copy_dO.partition_S(cOdO) + t0OcdO = gmem_thr_copy_dO.get_slice(0).partition_S(cOdO) + tOpdO = utils.predicate_k(tOcdO, limit=mdO.shape[3]) + + seqlen_q = mO.shape[1] + seqlen_q_rounded = cute.round_up(seqlen_q, self.m_block_size) + + if cutlass.const_expr(mLSE is not None): + gLSE = cute.local_tile( + mLSE[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) + lse = cutlass.Float32.inf + if tidx < seqlen_q - m_block * self.m_block_size: + lse = gLSE[tidx] + + tOrO = cute.make_fragment_like(tOgO) + tOrdO = cute.make_fragment_like(tOgdO) + assert cute.size(tOgO, mode=[0]) == cute.size(tOgdO, mode=[0]) + assert cute.size(tOgO, mode=[1]) == cute.size(tOgdO, mode=[1]) + assert cute.size(tOgO, mode=[2]) == cute.size(tOgdO, mode=[2]) + for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): + # Instead of using tOcO, we using t0OcO and subtract the offset from the limit + # (seqlen_q - m_block * kBlockM). This is because the entries of t0OcO are known at compile time. 
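# A toy check (plain Python, hypothetical numbers, not part of the kernel) of the
# index identity the comment above relies on: the absolute row of entry m for this
# thread is t0OcO[0, m, 0][0] + tOcO[0][0], so "absolute_row < limit" can be
# rewritten with a compile-time-constant left-hand side.
_thread_row_offset = 3           # stands in for tOcO[0][0]
_limit = 10                      # stands in for seqlen_q - m_block * m_block_size
for _row_thread0 in (0, 8, 16):  # stands in for t0OcO[0, m, 0][0] across m
    _absolute_row = _row_thread0 + _thread_row_offset
    assert (_absolute_row < _limit) == (_row_thread0 < _limit - _thread_row_offset)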
+ if t0OcO[0, m, 0][0] < seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_thr_copy_O, + tOgO[None, m, None], + tOrO[None, m, None], + pred=tOpO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + cute.copy( + gmem_thr_copy_dO, + tOgdO[None, m, None], + tOrdO[None, m, None], + pred=tOpdO[None, m, None] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + # Sum across the "k" dimension + dpsum = (tOrO.load().to(cutlass.Float32) * tOrdO.load().to(cutlass.Float32)).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) + ) + dpsum = utils.warp_reduce(dpsum, operator.add, width=self.gmem_threads_per_row) + dP_sum = cute.make_fragment(cute.size(tOrO, mode=[1]), cutlass.Float32) + dP_sum.store(dpsum) + + # Write dPsum from rmem -> gmem + gdPsum = cute.local_tile( + mdPsum[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) + # Only the thread corresponding to column 0 writes out the lse to gmem + if tOcO[0, 0, 0][1] == 0: + for m in cutlass.range(cute.size(dP_sum), unroll_full=True): + row = tOcO[0, m, 0][0] + gdPsum[row] = dP_sum[m] if row < mO.shape[1] - m_block * self.m_block_size else 0.0 + + # Clear dQaccum + if cutlass.const_expr(mdQaccum is not None): + blkdQaccum_shape = (self.m_block_size * self.head_dim_padded,) + gdQaccum = cute.local_tile( + mdQaccum[batch_size, num_head, None], blkdQaccum_shape, (m_block,) + ) + gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) + tQgQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) + zero = cute.make_fragment_like(tQgQaccum) + zero.fill(0.0) + cute.copy(gmem_tiled_copy_dQaccum, zero, tQgQaccum) + + if cutlass.const_expr(mLSE is not None): + gLSElog2 = cute.local_tile( + mLSElog2[batch_size, num_head, None], (self.m_block_size,), (m_block,) + ) + LOG2_E = math.log2(math.e) + if tidx < seqlen_q_rounded - m_block * self.m_block_size: + gLSElog2[tidx] = lse * LOG2_E if lse != -cutlass.Float32.inf else 0.0 diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py new file mode 100644 index 0000000000..d1b307acf0 --- /dev/null +++ b/flash_attn/cute/flash_fwd.py @@ -0,0 +1,1932 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of +# https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h +# and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h +# from Cutlass C++ to Cute-DSL. 
+# Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py + +import math +from types import SimpleNamespace +from typing import Type, Callable, Optional, Tuple +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +import cutlass.utils.ampere_helpers as sm80_utils_basic +import cutlass.utils.hopper_helpers as sm90_utils_basic + +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute import hopper_helpers as sm90_utils +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase + + +class FlashAttentionForwardBase: + + arch: int = 80 + + def __init__( + self, + dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: int = 1, + is_causal: bool = False, + is_local: bool = False, + pack_gqa: bool = True, + m_block_size: int = 128, + n_block_size: int = 128, + num_stages: int = 1, + num_threads: int = 128, + Q_in_regs: bool = False, + ): + """Initializes the configuration for a flash attention kernel. + + All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + """ + self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + # Can save registers (and hence be faster) if we don't have to check hdim predication + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.qhead_per_kvhead = qhead_per_kvhead + self.is_causal = is_causal + self.is_local = is_local + self.pack_gqa = pack_gqa + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.num_threads = num_threads + self.num_stages = num_stages + self.Q_in_regs = Q_in_regs + + @staticmethod + def can_implement( + dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages, num_threads, is_causal, + Q_in_regs=False + ) -> bool: + """Check if the kernel can be implemented with the given parameters. 
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + if dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if head_dim % 8 != 0: + return False + if head_dim_v % 8 != 0: + return False + if n_block_size % 16 != 0: + return False + if num_threads % 32 != 0: + return False + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage_Q = m_block_size * head_dim * 2 + smem_usage_K = n_block_size * head_dim * num_stages * 2 + smem_usage_V = n_block_size * head_dim_v * num_stages * 2 + smem_usage_QV = (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) + smem_usage = smem_usage_QV + smem_usage_K + # TODO: sm86 and sm89 + smem_capacity = sm80_utils_basic.SMEM_CAPACITY["sm80"] + if smem_usage > smem_capacity: + return False + # Check if twice the block size is divisible by the number of threads + if (m_block_size * 2) % num_threads != 0: + return False + return True + + def _check_type( + self, + mQ_type: Type[cutlass.Numeric], + mK_type: Type[cutlass.Numeric], + mV_type: Type[cutlass.Numeric], + mO_type: Type[cutlass.Numeric], + mLSE_type: Type[cutlass.Numeric] | None, + mCuSeqlensQ_type: Type[cutlass.Numeric] | None, + mCuSeqlensK_type: Type[cutlass.Numeric] | None, + mSeqUsedQ_type: Type[cutlass.Numeric] | None, + mSeqUsedK_type: Type[cutlass.Numeric] | None, + ): + # Get the data type and check if it is fp16 or bf16 + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + raise TypeError("All tensors must have the same data type") + if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): + raise TypeError("Only Float16 or BFloat16 is supported") + if const_expr(mLSE_type not in [None, Float32]): + raise TypeError("LSE tensor must be Float32") + if const_expr(mCuSeqlensQ_type not in [None, Int32]): + raise TypeError("cu_seqlens_q tensor must be Int32") + if const_expr(mCuSeqlensK_type not in [None, Int32]): + raise TypeError("cu_seqlens_k tensor must be Int32") + if const_expr(mSeqUsedQ_type not in [None, Int32]): + raise TypeError("seqused_q tensor must be Int32") + if const_expr(mSeqUsedK_type not in [None, Int32]): + raise TypeError("seqused_k tensor must be Int32") + assert mQ_type == self.dtype + + def _setup_attributes(self): + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = self._get_smem_layout_atom() + self.sQ_layout = cute.tile_to_shape( + sQ_layout_atom, (self.m_block_size, self.head_dim_padded), (0, 1), + ) + self.sK_layout = cute.tile_to_shape( + sK_layout_atom, (self.n_block_size, self.head_dim_padded, self.num_stages), (0, 1, 2), + ) + self.sV_layout = cute.tile_to_shape( + sV_layout_atom, (self.n_block_size, self.head_dim_v_padded, self.num_stages), (0, 1, 2), + ) + self.sO_layout = cute.tile_to_shape( + sO_layout_atom, (self.m_block_size, self.head_dim_v_padded), (0, 1), + ) + if 
const_expr(sP_layout_atom is not None): + self.sP_layout = cute.tile_to_shape( + sP_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), + ) + else: + self.sP_layout = None + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, + ) + # tQ_layout and tK_layout: thread layout for QK load + tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + assert self.num_producer_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" + tQ_layout = cute.make_ordered_layout( + (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + ) + tK_layout = cute.make_ordered_layout( + (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we load Q + assert self.m_block_size % tQ_layout.shape[0] == 0 + tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems + tV_layout = cute.make_ordered_layout( + (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + ) + # TODO: need a different layout for O if O dtype is not the same as V dtype + # tO_layout: thread layout for O store + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout) + self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout) + self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout) + # gmem_tiled_copy_O: tiled copy for O store + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + def _get_smem_layout_atom(self): + raise NotImplementedError() + + def _get_tiled_mma(self): + raise NotImplementedError() + + def _get_shared_storage_cls(self): + raise NotImplementedError() + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + softcap: Float32, + stream: cuda.CUstream, + ): + """Configures and launches the flash attention kernel. 
+ + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + raise NotImplementedError() + + @cute.jit + def epilogue( + self, + acc_O: cute.Tensor, + lse: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sO: cute.Tensor, + seqlen: SeqlenInfoQK, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tiled_mma: cute.TiledMma, + tidx: Int32, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + ): + # store acc_O + rO = cute.make_fragment_like(acc_O, self.dtype) + rO.store(acc_O.load().to(self.dtype)) + # Make sure all threads have finished reading V + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_O, taccOrO, taccOsO) + + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) + + # Write LSE from rmem -> gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + if const_expr(not self.pack_gqa): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,)) + gLSE_expanded_layout = cute.append( + gLSE.layout, cute.make_layout((self.head_dim_v_padded,), stride=(0,)) + ) + gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) + thr_mma = tiled_mma.get_slice(tidx) + taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded)) + assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse) + taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO)) + t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO)) + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0: + for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): + if t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - taccOcO[0][0]: + taccOgLSE[m, 0] = lse[m] + else: + pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) + + if const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + # thr_mma = tiled_mma.get_slice(tidx) + # taccOgO = thr_mma.partition_C(gO) + # cute.autovec_copy(rO, taccOgO) + # sync to make sure all smem stores are done + if const_expr(self.use_tma_O): + # ensure smem writes are visible to TMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.barrier_arrive(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) + tOsO, tOgO = cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + 
cute.group_modes(gO, 0, 2), + ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE) + cute.copy(tma_atom_O, tOsO, tOgO) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOrO = cute.make_fragment_like(tOsO, self.dtype) + # load acc O from smem to rmem for wider vectorization + cute.autovec_copy(tOsO, tOrO) + if const_expr(not self.pack_gqa): + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (m_block, 0)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] if const_expr(self.check_hdim_v_oob) else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) + + @cute.jit + def advance_pipeline(self, pipeline_index): + return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0 + + @cute.jit + def load_Q( + self, + gmem_thr_copy: cute.TiledCopy, + gQ: cute.Tensor, + sQ: cute.Tensor, + block: Int32, + seqlen: Int32, + headdim: Int32, + ): + tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=headdim) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + # Instead of using tQcQ, we using t0QcQ and subtract the offset from the limit + # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. + if t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]: + cute.copy( + gmem_thr_copy, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def load_K( + self, + gmem_tiled_copy: cute.TiledCopy, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + tKcK: cute.Tensor, + t0KcK: cute.Tensor, + tKpK: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load K? + is_even_n_smem_k = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_k): + # Instead of using tKcK, we using t0KcK and subtract the offset from the limit + # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. 
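# A concrete example (plain Python, hypothetical sizes, not part of the kernel) of
# the residue handling in the branch below: with seqlen_k = 1000 and
# n_block_size = 128 the last K block is block 7, and only 1000 - 7*128 = 104 of
# its 128 rows are in bounds, so the remaining rows are simply not loaded (their
# scores get masked out later anyway).
_seqlen_k, _n_block_size = 1000, 128
_last_block = -(-_seqlen_k // _n_block_size) - 1                 # ceil_div - 1 = 7
_valid_rows = min(_seqlen_k - _last_block * _n_block_size, _n_block_size)
assert (_last_block, _valid_rows) == (7, 104)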
+ if const_expr(is_even_n_smem_k): + seqlen_limit = seqlen - block * self.n_block_size + else: + if const_expr(not need_predicates): + seqlen_limit = self.n_block_size + else: + seqlen_limit = cutlass.min(seqlen - block * self.n_block_size, self.n_block_size) + seqlen_limit -= tKcK[0][0] + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if t0KcK[0, n, 0][0] < seqlen_limit: + cute.copy( + gmem_tiled_copy, + tKgK[None, n, None, block], + tKsK[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + else: + cute.copy( + gmem_tiled_copy, + tKgK[None, None, None, block], + tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tKpK if const_expr(self.check_hdim_oob) else None, + ) + + @cute.jit + def load_V( + self, + gmem_tiled_copy: cute.TiledCopy, + tVgV: cute.Tensor, + tVsV: cute.Tensor, + tVcV: cute.Tensor, + t0VcV: cute.Tensor, + tVpV: cute.Tensor, + block: Int32, + smem_pipe_write: Int32, + seqlen: Int32, + need_predicates: cutlass.Constexpr, + ): + # Do we need to check if we overshoot kBlockN when we load V? + is_even_n_smem_v = self.n_block_size % gmem_tiled_copy.tiler_mn[0].shape == 0 + if const_expr(need_predicates or not is_even_n_smem_v): + for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): + # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked + if is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size: + predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None + if const_expr(need_predicates): + seqlen_limit = seqlen - block * self.n_block_size - tVcV[0][0] + predicate_n = t0VcV[0, n, 0][0] < seqlen_limit + predicate = cute.make_fragment_like(tVpV[None, 0, None]) + for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): + for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): + predicate[i, k] = (tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True) and predicate_n + cute.copy( + gmem_tiled_copy, + tVgV[None, n, None, block], + tVsV[None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=predicate, + ) + else: + cute.copy( + gmem_tiled_copy, + tVgV[None, None, None, block], + tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], + pred=tVpV if const_expr(self.check_hdim_v_oob) else None, + ) + + +class FlashAttentionForwardSm80(FlashAttentionForwardBase): + + def _get_smem_layout_atom(self): + sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) + sO_layout_atom = sV_layout_atom + sP_layout_atom = None + return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + tiled_mma_pv = cute.make_tiled_mma( + warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), + (self.num_threads // 32, 1, 1), + permutation_mnk=(self.num_threads // 32 * 16, 16, 16), + ) + return tiled_mma_qk, tiled_mma_pv + + def _get_shared_storage_cls(self): + sQ_struct, sK_struct, sV_struct = [ + 
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] + for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + + @cute.struct + class SharedStorageQKV: + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + + @cute.struct + class SharedStorageSharedQV: + sQ: sQV_struct + sK: sK_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + stream: cuda.CUstream, + softmax_scale: Optional[Float32] = None, + softcap: Optional[Float32] = None, + window_size_left: Optional[Int32] = None, + window_size_right: Optional[Int32] = None, + learnable_sink: Optional[cute.Tensor] = None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + assert learnable_sink is None, "Learnable sink is not supported in this kernel" + self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE))) + tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_pv.size + self.num_producer_threads = self.num_threads + self.num_Q_load_threads = self.num_threads + self.num_epilogue_threads = self.num_threads + # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None + self.use_tma_O = self.arch >= 90 + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.select(t.layout, mode=[1, 3, 2, 0])) for t in (mQ, mK, mV, mO)] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=[2, 1, 0])) + # grid_dim: (m_block, num_head, batch_size) + grid_dim = ( + cute.ceil_div(mQ.shape[0], self.m_block_size), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]), + ) + # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + # Right after this, we multiply by log2(e) before applying exp2. + # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) + # (assigning it to softmax_scale_log2). 
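# A quick numeric check (plain Python, hypothetical values, not part of the kernel)
# that the pre-multiplied constants below give the same base-2 logit as the
# straightforward form tanh(S * softmax_scale / softcap) * softcap * log2(e):
import math as _math
_scale, _softcap, _s = 0.125, 30.0, 517.0
_reference = _math.tanh(_s * _scale / _softcap) * _softcap * _math.log2(_math.e)
_softcap_val = _scale / _softcap                       # applied in scoremod_premask_fn
_softmax_scale_log2 = _softcap * _math.log2(_math.e)   # applied inside the softmax
assert _math.isclose(_reference, _math.tanh(_s * _softcap_val) * _softmax_scale_log2)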
+ LOG2_E = math.log2(math.e) + if const_expr(softcap is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softcap_val = None + else: + softmax_scale_log2 = softcap * LOG2_E + softcap_val = Float32(softmax_scale / softcap) + + self.kernel( + mQ, + mK, + mV, + mO, + mLSE, + softmax_scale_log2, + softcap_val, + window_size_left, + window_size_right, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, num_head, batch_size = cute.arch.block_idx() + + block_info = BlockInfo( + self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + seqlen = SeqlenInfoQK(seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0]) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # TODO: return early if n_block_max == 0 + # if self.is_causal: + # if n_block_max <= 0: + # return + n_block = n_block_max - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
+ # /////////////////////////////////////////////////////////////////////////////// + blkQ_shape = (self.m_block_size, self.head_dim_padded) + blkK_shape = (self.n_block_size, self.head_dim_padded) + blkV_shape = (self.n_block_size, self.head_dim_v_padded) + gQ = cute.local_tile(mQ[None, None, num_head, batch_size], blkQ_shape, (m_block, 0)) + num_head_kv = num_head // self.qhead_per_kvhead + gK = cute.local_tile(mK[None, None, num_head_kv, batch_size], blkK_shape, (None, 0)) + gV = cute.local_tile(mV[None, None, num_head_kv, batch_size], blkV_shape, (None, 0)) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sK_layout) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout) + else: + sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) + # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + sVt = utils.transpose_view(sV) + + gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) + gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + thr_mma_pv = tiled_mma_pv.get_slice(tidx) + tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) + tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) + tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) + acc_shape_O = thr_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_QK = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, + ) + smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) + smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSsK = smem_thr_copy_K.partition_S(sK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for KV + cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded)) + tKcK = gmem_thr_copy_K.partition_S(cK) + t0KcK = 
gmem_thr_copy_K.get_slice(0).partition_S(cK) + if const_expr(self.head_dim_padded == self.head_dim_v_padded): + tVcV = tKcK + t0VcV = t0KcK + else: + cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded)) + tVcV = gmem_thr_copy_V.partition_S(cV) + t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV) + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and + # use "if" on the mn dimension. + # This is to reduce register pressure and gets 2-3% performance gain. + tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) + if const_expr(self.same_hdim_kv): + tVpV = tKpK + else: + tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) + + # shape: (atom_v_m * rest_m) + softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + softmax.reset() + + # group parameters for compute_one_n_block + mma_params = SimpleNamespace( + thr_mma_qk=thr_mma_qk, thr_mma_pv=thr_mma_pv, + tSrQ=tSrQ, tSrK=tSrK, tOrVt=tOrVt, acc_O=acc_O, + ) + smem_copy_params = SimpleNamespace( + smem_thr_copy_Q=smem_thr_copy_Q, + smem_thr_copy_K=smem_thr_copy_K, + smem_thr_copy_V=smem_thr_copy_V, + tSsQ=tSsQ, tSsK=tSsK, tOsVt=tOsVt, + ) + load_K = partial(self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, + seqlen=seqlen.seqlen_k) + load_V = partial(self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, + seqlen=seqlen.seqlen_k) + # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + # -inf to e.g. -50.0, which can affect the attention softmax. + def scoremod_premask_fn(acc_S): + if const_expr(softcap_val is not None): + acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) + + compute_one_n_block = partial( + self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, + softmax=softmax, load_K=load_K, load_V=load_V, scoremod_premask_fn=scoremod_premask_fn, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1]) + cute.arch.cp_async_commit_group() + + def preprocess_Q(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) + if const_expr(self.Q_in_regs): + cute.arch.barrier() + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) + + # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and + # read from smem_q to registers, then load V. + # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q. 
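# A rough schematic (plain Python, not part of the kernel) of the prologue load
# order the comment above describes, matching the code below for a hypothetical
# num_stages = 2 and an n_block large enough that no loads are skipped:
def _prologue_order(num_stages, q_in_regs, n_block):
    ops = ["load_Q"]
    if q_in_regs:
        ops += ["load_K(stage 0)", "Q smem->regs"]
    for stage in range(num_stages):
        if not q_in_regs or stage > 0:
            if stage == 0 or n_block - stage >= 0:
                ops.append(f"load_K(stage {stage})")
        if stage < num_stages - 1:
            if stage == 0 or n_block - stage >= 0:
                ops.append(f"load_V(stage {stage})")
    if not q_in_regs:
        ops.append("Q smem->regs")
    return ops

assert _prologue_order(2, q_in_regs=False, n_block=3) == [
    "load_Q", "load_K(stage 0)", "load_V(stage 0)", "load_K(stage 1)", "Q smem->regs"]
assert _prologue_order(2, q_in_regs=True, n_block=3) == [
    "load_Q", "load_K(stage 0)", "Q smem->regs", "load_V(stage 0)", "load_K(stage 1)"]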
+ if const_expr(self.Q_in_regs): + load_K(n_block, smem_pipe_write=0, need_predicates=True) + cute.arch.cp_async_commit_group() + preprocess_Q() + cute.arch.barrier() # Make sure all threads have read smem_q before loading V + + for stage in cutlass.range_constexpr(self.num_stages): + if const_expr(not self.Q_in_regs or stage > 0): + if stage == 0 or n_block - stage >= 0: + load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + cute.arch.cp_async_commit_group() + if const_expr(stage < self.num_stages - 1): + if stage == 0 or n_block - stage >= 0: + load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage==0) + cute.arch.cp_async_commit_group() + if const_expr(not self.Q_in_regs): + preprocess_Q() + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # /////////////////////////////////////////////////////////////////////////////// + # Start processing of the first n-block. + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last several blocks. + mask = AttentionMask( + self.m_block_size, self.n_block_size, seqlen.seqlen_q, seqlen.seqlen_k, + window_size_left, window_size_right, + self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + mask_fn = partial( + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, + mask_causal=self.is_causal, mask_local=self.is_local, + ) + + # First iteration with seqlen masking + smem_pipe_read = Int32(0) + smem_pipe_write = Int32(self.num_stages - 1) + compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 2 - n_tile + compute_one_n_block(n_block, smem_pipe_read, smem_pipe_write, check_inf=True, + mask_fn=partial(mask_fn, mask_seqlen=False)) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # The remaining iterations have no masking + for n_tile in cutlass.range(n_block, unroll=1): + compute_one_n_block(n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, check_inf=True) + smem_pipe_read = self.advance_pipeline(smem_pipe_read) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + # TODO: local + + # normalize acc_O by row_sum and calculate the lse + row_scale = softmax.finalize() + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # reuse sQ's data iterator + sO = cute.make_tensor(sQ.iterator, sO_layout) + self.epilogue( + acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, + gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size + ) + + @cute.jit + def compute_one_n_block( + self, + n_block: Int32, + smem_pipe_read: Int32, + smem_pipe_write: Int32, + 
mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + load_K: Callable, + load_V: Callable, + scoremod_premask_fn: Callable, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + ): + """Compute one n_block of S/O. + + This function provides different variants for processing the first n block versus + subsequent blocks. + """ + def sync(): + cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) + cute.arch.barrier() + + acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) + acc_S = cute.make_fragment(acc_shape_S, Float32) + acc_S.fill(0.0) + # wait for smem tile QK before mma calculation for S + sync() + # need predicates for the first tile + def load_V_next(): + if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: + load_V(n_block - self.num_stages + 1, smem_pipe_write, + need_predicates=is_first_n_block and self.num_stages == 1) + cute.arch.cp_async_commit_group() + load_V_next() + sm80_utils.gemm( + mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, + smem_copy_params.tSsQ, + smem_copy_params.tSsK[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], + smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, + # hook_fn=load_V_next, + A_in_regs=self.Q_in_regs, + ) + scoremod_premask_fn(acc_S) + smem_pipe_write = self.advance_pipeline(smem_pipe_write) + def load_K_next(): + if n_block - self.num_stages >= 0: + load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) + cute.arch.cp_async_commit_group() + # wait for smem tile V for O + if const_expr(self.num_stages == 1): + sync() + load_K_next() + if const_expr(mask_fn is not None): + mask_fn(acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + softmax.rescale_O(mma_params.acc_O, row_scale) + rP = cute.make_fragment_like(acc_S, self.dtype) + rP.store(acc_S.load().to(self.dtype)) + tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout)) + if const_expr(self.num_stages > 1): + sync() + load_K_next() + sm80_utils.gemm_rs( + mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, + smem_copy_params.tOsVt[None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0], + smem_copy_params.smem_thr_copy_V, + # hook_fn=load_K_next, + ) + # if const_expr(self.num_stages > 1): + # load_K_next() + + +class FlashAttentionForwardSm90(FlashAttentionForwardBase): + + arch = 90 + + def __init__(self, *args, intra_wg_overlap: bool = True, mma_pv_is_rs: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.intra_wg_overlap = intra_wg_overlap + self.mma_pv_is_rs = mma_pv_is_rs + + def _get_smem_layout_atom(self): + sQ_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_padded + ), + self.dtype + ) + sK_layout_atom = sQ_layout_atom + sV_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.head_dim_v_padded + ), + self.dtype + ) + sO_layout_atom = sV_layout_atom + if not self.mma_pv_is_rs: + sP_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_basic.get_smem_layout_atom( + cutlass.utils.LayoutEnum.ROW_MAJOR, self.dtype, self.n_block_size + ), + self.dtype + ) + else: + sP_layout_atom = None + return sQ_layout_atom, 
sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom + + def _get_tiled_mma(self): + tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.K, + Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.n_block_size), + ) + tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.head_dim_v_padded), + a_source=warpgroup.OperandSource.RMEM if self.mma_pv_is_rs else warpgroup.OperandSource.SMEM, + ) + tiled_mma_pv_rs = sm90_utils_basic.make_trivial_tiled_mma( + self.dtype, + self.dtype, + warpgroup.OperandMajorMode.K, + warpgroup.OperandMajorMode.MN, + Float32, + atom_layout_mnk=(self.m_block_size // 64, 1, 1), # Might need (1, 2, 1) for hdim 512 + tiler_mn=(64, self.head_dim_v_padded), + a_source=warpgroup.OperandSource.RMEM + ) + return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs + + def _get_shared_storage_cls(self): + # If we use cp.async to load Q, we want sQ to align to 1024 bytes + sQ_alignment = 128 if const_expr(self.use_tma_Q) else 1024 + sK_alignment = 128 + sV_alignment = 128 + sQ_struct, sK_struct, sV_struct = [ + cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], alignment] + for layout, alignment in zip( + (self.sQ_layout, self.sK_layout, self.sV_layout), + (sQ_alignment, sK_alignment, sV_alignment) + ) + ] + cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) + sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] + cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0 + sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024] + # 1 for Q, 1 for O, self.num_stages*2 for K, self.num_stages*2 for V, + mbar_ptr_QO_struct = cute.struct.MemRange[cutlass.Int64, 2] + mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2] + + @cute.struct + class SharedStorageQKV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sV: sV_struct + sQ: sQ_struct + sK: sK_struct + sP: sP_struct + + @cute.struct + class SharedStorageSharedQV: + mbar_ptr: mbar_ptr_QO_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_V: mbar_ptr_V_struct + sQ: sQV_struct + sK: sK_struct + sP: sP_struct + + return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, 
max_num_pages_per_seq) + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + ): + """Configures and launches the flash attention kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) + """ + self._check_type( + *(t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) + ) + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) + for t in (mQ, mO) + ] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None + tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs = self._get_tiled_mma() + self.num_mma_threads = tiled_mma_qk.size + self.num_threads_per_warp_group = 128 + self.num_mma_warp_groups = self.num_mma_threads // self.num_threads_per_warp_group + self.num_threads = self.num_threads_per_warp_group * (self.num_mma_warp_groups + 1) + self.num_producer_threads = 32 + self.num_Q_load_threads = self.num_mma_threads # If not TMA_Q, MMA threads load Q + self.num_epilogue_threads = self.num_mma_threads + self.num_mma_regs = ( + 256 + if self.num_mma_warp_groups == 1 + else (240 if self.num_mma_warp_groups == 2 else 160) + ) + self.num_producer_regs = ( + 56 if self.num_mma_warp_groups == 1 else (24 if self.num_mma_warp_groups == 2 else 32) + ) + # self.num_mma_regs = 232 + # self.num_producer_regs = 40 + self.use_scheduler_barrier = (self.num_mma_warp_groups >= 2 and self.head_dim_padded <= 128) if const_expr(self.intra_wg_overlap) else (self.num_mma_warp_groups == 2) + self.use_tma_Q = self.arch >= 90 and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None and not self.pack_gqa + # TODO: rescale_O_before_gemm + self._setup_attributes() + SharedStorage = self._get_shared_storage_cls() + + if const_expr(self.pack_gqa): + shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) + stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) + mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) + shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) + stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) + mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + if const_expr(mLSE is not None): + shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), 
mK.shape[2], *mLSE.shape[2:]) + stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) + mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + + # TMA + gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp() + gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp() # Might multicast + gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp() + self.tma_copy_q_bytes = cute.size_in_bytes(mQ.element_type, cute.select(self.sQ_layout, mode=[0, 1])) + self.tma_copy_k_bytes = cute.size_in_bytes(mK.element_type, cute.select(self.sK_layout, mode=[0, 1])) + self.tma_copy_v_bytes = cute.size_in_bytes(mV.element_type, cute.select(self.sV_layout, mode=[0, 1])) + if const_expr(self.use_tma_Q): + tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_Q, mQ, self.sQ_layout, (self.m_block_size, self.head_dim_padded), # No mcast + ) + else: + tma_atom_Q, tma_tensor_Q = None, None + tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mK, + cute.select(self.sK_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim_padded), + 1 # No mcast for now + ) + tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_KV, + mV, + cute.select(self.sV_layout, mode=[0, 1]), + (self.n_block_size, self.head_dim_v_padded), + 1 # No mcast for now + ) + if const_expr(self.use_tma_O): + tma_atom_O, mO = cpasync.make_tiled_tma_atom( + gmem_tiled_copy_O, mO, self.sO_layout, (self.m_block_size, self.head_dim_v_padded), # No mcast + ) + else: + tma_atom_O = None + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + TileScheduler = SingleTileScheduler if const_expr(not self.is_causal or self.is_local) else SingleTileLPTScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.m_block_size), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mK.shape[0]), + mQ.shape[1], + mV.shape[1], + total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=(self.m_block_size, self.n_block_size), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype.width // 8, + is_persistent=False, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + # Right after this, we multiply by log2(e) before applying exp2. + # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) + # (assigning it to softmax_scale_log2). 
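+        # In other words, writing S for the raw scores, `softcap` for the user-provided cap
+        # and `softmax_scale` for the scale, the identity being exploited is
+        #   tanh(S * softmax_scale / softcap) * softcap * LOG2_E
+        #     = tanh(S * softcap_val) * softmax_scale_log2
+        # with softcap_val = softmax_scale / softcap and softmax_scale_log2 = softcap * LOG2_E,
+        # which is exactly what the branch below computes.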
+ LOG2_E = math.log2(math.e) + if const_expr(softcap is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softcap_val = None + else: + softmax_scale_log2 = softcap * LOG2_E + softcap_val = Float32(softmax_scale / softcap) + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + self.kernel( + tma_tensor_Q if const_expr(self.use_tma_Q) else mQ, + tma_tensor_K, + tma_tensor_V, + mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softcap_val, + window_size_left, + window_size_right, + learnable_sink, + self.sQ_layout, + self.sK_layout, + self.sV_layout, + self.sO_layout, + self.sP_layout, + self.gmem_tiled_copy_Q, + self.gmem_tiled_copy_K, + self.gmem_tiled_copy_V, + self.gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + tile_sched_params, + TileScheduler, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=SharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + tma_atom_Q: Optional[cute.CopyAtom], + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + sP_layout: cute.ComposedLayout | None, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_K: cute.TiledCopy, + gmem_tiled_copy_V: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Prefetch tma descriptor + if warp_idx == 0: + if const_expr(tma_atom_Q is not None): + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) + + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Mbarrier init + mbar_ptr_Q = storage.mbar_ptr.data_ptr() + if warp_idx == 1: + # if tidx < 2: + # # barrierO num threads should be self.num_mma_threads + # cute.arch.mbarrier_init(mbar_ptr_Q + tidx, 1 if tidx == 0 else self.num_mma_threads) + cute.arch.mbarrier_init(mbar_ptr_Q, 1 if const_expr(self.use_tma_Q) else self.num_Q_load_threads) + # cute.arch.mbarrier_init(mbar_ptr_Q + 1, self.num_mma_threads) + # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync + pipeline_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread) + pipeline_kv_consumer_group = cutlass.pipeline.CooperativeGroup( + cutlass.pipeline.Agent.Thread, self.num_mma_threads // self.num_threads_per_warp_group + ) + 
pipeline_k = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_K.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_k_bytes, + init_wait=False, + ) + pipeline_v = pipeline.PipelineTmaAsyncNoCluster.create( + barrier_storage=storage.mbar_ptr_V.data_ptr(), + num_stages=self.num_stages, + producer_group=pipeline_kv_producer_group, + consumer_group=pipeline_kv_consumer_group, + tx_count=self.tma_copy_v_bytes, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + # TODO: how to get sQ_pi for cp.async if pack_gqa? + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + if const_expr(not self.Q_in_regs): + sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) + else: + sV = storage.sQ.get_tensor(sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type) + # Transpose view of V to tensor with layout (head_dim_v, n_block_size) for tiled mma + sVt = utils.transpose_view(sV) + if const_expr(sP_layout is not None): + sP_pi = storage.sP.get_tensor(sP_layout) + sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) + else: + sP, sP_pi = None, None + # reuse sQ's data iterator + sO_pi = storage.sQ.get_tensor(sO_layout) + # TODO: idk why not using sO_pi is faster + sO = cute.make_tensor(cute.recast_ptr(sO_pi.iterator, sO_layout.inner, dtype=sO_pi.element_type), sO_layout.outer) + + block_info = BlockInfo( + self.m_block_size, self.n_block_size, self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, self.m_block_size, self.n_block_size, + window_size_left=window_size_left, window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) + + if warp_idx < 4: # Producer + cute.arch.warpgroup_reg_dealloc(self.num_producer_regs) + self.load( + mQ, + mK, + mV, + sQ, + sK, + sV, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + else: # Consumer + cute.arch.warpgroup_reg_alloc(self.num_mma_regs) + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - 128 + self.mma( + tiled_mma_qk, + tiled_mma_pv, + tiled_mma_pv_rs, + mQ, + mO, + mLSE, + sQ, + sK, + sVt, + sP, + sO, + learnable_sink, + pipeline_k, + pipeline_v, + mbar_ptr_Q, + gmem_tiled_copy_Q, + gmem_tiled_copy_O, + tma_atom_O, + tidx, + softmax_scale_log2, + softcap_val, + block_info, + SeqlenInfoCls, + AttentionMaskCls, + TileSchedulerCls, + ) + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: 
cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + if warp_idx_in_wg == 0: + q_producer_phase = Int32(1) + kv_producer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Producer, self.num_stages + ) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, None, head_idx_kv]) for t in (mK, mV)] + gK = cute.local_tile(mK_cur, (self.n_block_size, self.head_dim_padded), (None, 0)) + gV = cute.local_tile(mV_cur, (self.n_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(self.use_tma_Q): + gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 2), + cute.group_modes(gQ, 0, 2), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 2), + cute.group_modes(gK, 0, 2), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 2), + cute.group_modes(gV, 0, 2), + ) + load_K = partial(self.load_K, tma_atom_K, tKgK, tKsK, pipeline_k) + load_V = partial(self.load_K, tma_atom_V, tVgV, tVsV, pipeline_v) + # load_Q + if const_expr(self.use_tma_Q): + # TODO: wait for Q to be empty + q_producer_phase ^= 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_ptr_Q, self.tma_copy_q_bytes) + cute.copy(tma_atom_Q, tQgQ, tQsQ, tma_bar_ptr=mbar_ptr_Q) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) + for i in cutlass.range(n_block_max - n_block_min, unroll=2): + n_block = n_block_max - i - 1 + load_K(n_block, producer_state=kv_producer_state) + load_V(n_block, producer_state=kv_producer_state) + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + # softmax: Softmax, + # acc_O: cute.Tensor, + mQ: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + sQ: cute.Tensor, + sK: cute.Tensor, + sVt: cute.Tensor, + sP: Optional[cute.Tensor], + sO: 
cute.Tensor, + learnable_sink: Optional[cute.Tensor], + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mbar_ptr_Q: cutlass.Pointer, + gmem_tiled_copy_Q: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + tidx: Int32, + softmax_scale_log2: Float32, + softcap_val: Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + ): + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma_qk = tiled_mma_qk.get_slice(tidx) + wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)) + wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)) + tSrQ = tiled_mma_qk.make_fragment_A(wg_mma_qk.partition_A(sQ)) + tSrK = tiled_mma_qk.make_fragment_B(wg_mma_qk.partition_B(sK)) + if const_expr(self.mma_pv_is_rs): + acc_S_shape = tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)) + tOrP = cute.make_fragment( + utils.convert_layout_acc_frgA(cute.make_layout(acc_S_shape)), self.dtype + ) + else: + tOrP = tiled_mma_pv.make_fragment_A(wg_mma_pv.partition_A(sP)) + tOrVt = tiled_mma_pv.make_fragment_B(wg_mma_pv.partition_B(sVt)) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_P = utils.get_smem_store_atom(self.arch, self.dtype) + smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx) + # tPsP = smem_thr_copy_P.partition_D(sP_pi) if const_expr(sP_pi is not None) else None + tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None + # if cute.arch.thread_idx()[0] == 0: + # cute.printf(sP_pi.layout, sP_pi.iterator) + # cute.printf(sP.layout, sP.iterator) + # cute.printf(tPsP.layout, tPsP.iterator) + + self.mma_init() + + acc_shape_O = tiled_mma_pv.partition_shape_C((self.m_block_size, self.head_dim_v_padded)) + acc_O = cute.make_fragment(acc_shape_O, Float32) + # group parameters for mma_one_n_block + mma_params = SimpleNamespace(tSrQ=tSrQ, tSrK=tSrK, tOrP=tOrP, tOrVt=tOrVt, acc_O=acc_O) + smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP) + + mma_one_n_block_all = partial( + self.mma_one_n_block_intrawg_overlap if const_expr(self.intra_wg_overlap) else self.mma_one_n_block, + tiled_mma_qk=tiled_mma_qk, tiled_mma_pv=tiled_mma_pv, tiled_mma_pv_rs=tiled_mma_pv_rs, + pipeline_k=pipeline_k, pipeline_v=pipeline_v, + mma_params=mma_params, smem_copy_params=smem_copy_params, + check_inf=True, + ) + + q_consumer_phase = Int32(0) + kv_consumer_state = pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.num_stages + ) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # if work_tile.is_valid_tile: + # Softcapping needs to happen before masking since if we apply after masking, softcapping can turn + # -inf to e.g. -50.0, which can affect the attention softmax. 
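+            # Concretely: tanh(-inf) = -1, so a masked entry would end up at roughly
+            # -softcap instead of -inf after rescaling, giving it a small but nonzero
+            # weight (and possibly shifting the row max). Masking after softcapping
+            # keeps masked entries at exactly -inf so they contribute nothing.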
+ def scoremod_premask_fn(acc_S): + if const_expr(softcap_val is not None): + acc_S.store(cute.math.tanh(acc_S.load() * softcap_val, fastmath=True)) + + # shape: (atom_v_m * rest_m) + softmax = Softmax(softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1]) + mma_one_n_block = partial( + mma_one_n_block_all, softmax=softmax, scoremod_premask_fn=scoremod_premask_fn + ) + + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask, m_block=m_block, thr_mma=thr_mma_qk, + mask_causal=self.is_causal, mask_local=self.is_local, + ) + softmax.reset() + # Load Q if not TMA_Q + if const_expr(not self.use_tma_Q): + pack_gqa = PackGQA(self.m_block_size, self.head_dim_padded, self.check_hdim_oob, self.qhead_per_kvhead) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) + # gQ = cute.local_tile(mQ_cur, (self.m_block_size, self.head_dim_padded), (m_block, 0)) + # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, + # headdim=mQ.shape[1]) + pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q) + utils.cp_async_mbarrier_arrive_shared(mbar_ptr_Q, noinc=True) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase) + q_consumer_phase ^= 1 + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last several blocks. + O_should_accumulate = False + # First iteration with seqlen masking + if const_expr(self.intra_wg_overlap): + acc_S = cute.make_fragment( + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 + ) + pipeline_k.consumer_wait(kv_consumer_state) + sm90_utils.gemm( + tiled_mma_qk, acc_S, tSrQ, tSrK[None, None, None, kv_consumer_state.index], + zero_init=True, wg_wait=0 + ) + pipeline_k.consumer_release(kv_consumer_state) + scoremod_premask_fn(acc_S) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + softmax.online_softmax(acc_S, is_first=True) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + # the "to(self.dtype)" conversion fails to vectorize for block sizes other + # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of + # 2 elements. So we just call ptx directly. 
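+                # (Presumably utils.cvt_f16 emits the packed PTX conversion, e.g.
+                # cvt.rn.f16x2.f32 / cvt.rn.bf16x2.f32, turning two fp32 accumulator
+                # values into one packed 16-bit pair per instruction.)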
+ utils.cvt_f16(tOrP_acc, tOrP) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_thr_copy_P.retile(tOrP) + cute.copy(smem_thr_copy_P, tPrP, tPsP) + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + # Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + # acc_O.fill(0.0) + else: + self.warp_scheduler_barrier_sync() + kv_consumer_state = mma_one_n_block( + n_block_max - 1, kv_consumer_state, + is_first_n_block=True, mask_fn=partial(mask_fn, mask_seqlen=True), + O_should_accumulate=False + ) + O_should_accumulate = True + # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + kv_consumer_state = mma_one_n_block( + n_block, kv_consumer_state, mask_fn=partial(mask_fn, mask_seqlen=False), + O_should_accumulate=O_should_accumulate + ) + O_should_accumulate = True + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + kv_consumer_state = mma_one_n_block(n_block, kv_consumer_state, check_inf=True, O_should_accumulate=O_should_accumulate) + O_should_accumulate = True + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + kv_consumer_state = mma_one_n_block( + n_block, kv_consumer_state, + check_inf=True, mask_fn=partial(mask_fn, mask_seqlen=False), + O_should_accumulate=O_should_accumulate + ) + O_should_accumulate = True + # Last "half" iteration + if const_expr(self.intra_wg_overlap): + pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) + sm90_utils.gemm( + tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, + mma_params.tOrVt[None, None, None, kv_consumer_state.index], + zero_init=not O_should_accumulate, wg_wait=-1 + ) + warpgroup.wait_group(0) + pipeline_v.consumer_release(kv_consumer_state) + kv_consumer_state.advance() + else: + self.warp_scheduler_barrier_arrive() + + # normalize acc_O by row_sum and calculate the lse + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + else: # Each thread might have a different sink value due to different q_head + sink_val = 
cute.make_fragment_like(softmax.row_max, Float32) + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS)) + for r in cutlass.range(cute.size(sink_val), unroll_full=True): + row = m_block * self.m_block_size + tScS_mn[r][0] + q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + sink_val[r] = Float32(learnable_sink[q_head_idx]) + else: + sink_val = None + + row_scale = softmax.finalize(sink_val=sink_val) + softmax.rescale_O(acc_O, row_scale) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + self.epilogue( + acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, + gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, + ) + + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + @cute.jit + def mma_one_n_block( + self, + n_block: Int32, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + scoremod_premask_fn: Callable, + mask_fn: Optional[Callable] = None, + is_first_n_block: cutlass.Constexpr = False, + check_inf: cutlass.Constexpr = True, + O_should_accumulate: cutlass.Boolean = True, + ): + acc_S = cute.make_fragment( + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 + ) + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + sm90_utils.gemm( + tiled_mma_qk, acc_S, mma_params.tSrQ, + mma_params.tSrK[None, None, None, smem_pipe_read.index], + zero_init=True, wg_wait=-1 + ) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(0) + pipeline_k.consumer_release(smem_pipe_read) + scoremod_premask_fn(acc_S) + if const_expr(mask_fn is not None): + mask_fn(acc_S, n_block=n_block) + row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) + # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + utils.cvt_f16(tOrP_acc, tOrP) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(mma_params.acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + sm90_utils.gemm( + tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, + mma_params.tOrVt[None, None, None, smem_pipe_read.index], + zero_init=not O_should_accumulate, wg_wait=0 + ) + pipeline_v.consumer_release(smem_pipe_read) + smem_pipe_read.advance() + return 
smem_pipe_read + + @cute.jit + def mma_one_n_block_intrawg_overlap( + self, + n_block: Int32, + smem_pipe_read: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tiled_mma_pv_rs: cute.TiledMma, + pipeline_k: cutlass.pipeline.PipelineAsync, + pipeline_v: cutlass.pipeline.PipelineAsync, + mma_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax: Softmax, + scoremod_premask_fn: Callable, + mask_fn: Optional[Callable] = None, + check_inf: cutlass.Constexpr = True, + O_should_accumulate: cutlass.Boolean = True, + ): + smem_pipe_read_v = smem_pipe_read.clone() + smem_pipe_read.advance() + acc_S = cute.make_fragment( + tiled_mma_qk.partition_shape_C((self.m_block_size, self.n_block_size)), Float32 + ) + pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) + self.warp_scheduler_barrier_sync() + sm90_utils.gemm( + tiled_mma_qk, acc_S, mma_params.tSrQ, + mma_params.tSrK[None, None, None, smem_pipe_read.index], + zero_init=True, wg_wait=-1 + ) + pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) + sm90_utils.gemm( + tiled_mma_pv, mma_params.acc_O, mma_params.tOrP, + mma_params.tOrVt[None, None, None, smem_pipe_read_v.index], + zero_init=not O_should_accumulate, wg_wait=-1 + ) + self.warp_scheduler_barrier_arrive() + warpgroup.wait_group(1) + pipeline_k.consumer_release(smem_pipe_read) + scoremod_premask_fn(acc_S) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + if const_expr(mask_fn is not None): + mask_fn(acc_S, n_block=n_block) + # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S)) + row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) + warpgroup.wait_group(0) + pipeline_v.consumer_release(smem_pipe_read_v) + tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout)) + tOrP = mma_params.tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype) + # tOrP.store(tOrP_acc.load().to(self.dtype)) + utils.cvt_f16(tOrP_acc, tOrP) + if const_expr(not self.mma_pv_is_rs): + tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP) + cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) + softmax.rescale_O(mma_params.acc_O, row_scale) + if const_expr(not self.mma_pv_is_rs): + # Fence and barrier to make sure smem store is visible to WGMMA + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV + return smem_pipe_read + + @cute.jit + def mma_init(self): + warp_group_idx = utils.canonical_warp_group_idx(sync=False) + if const_expr(self.use_scheduler_barrier): + if warp_group_idx == 1: + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def warp_scheduler_barrier_sync(self): + if const_expr(self.use_scheduler_barrier): + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), + number_of_threads=2 * self.num_threads_per_warp_group + ) + + def warp_scheduler_barrier_arrive(self): + if const_expr(self.use_scheduler_barrier): + assert self.num_mma_warp_groups in [2, 3] + cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 + if const_expr(self.num_mma_warp_groups == 2): + next_wg = 
1 - cur_wg + else: + t = cur_wg + 1 + next_wg = t % self.num_mma_warp_groups + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + # @cute.jit + def load_K( + self, + tma_atom: cute.CopyAtom, + tKgK: cute.Tensor, + tKsK: cute.Tensor, + pipeline: cutlass.pipeline.PipelineAsync, + block: Int32, + producer_state: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, + ): + # TODO: mcast + # TODO check warp_idx if we have 128 producer threads + pipeline.producer_acquire(producer_state) + cute.copy( + tma_atom, + tKgK[None, block], + tKsK[None, producer_state.index], + tma_bar_ptr=pipeline.producer_get_barrier(producer_state) + ) diff --git a/flash_attn/cute/flash_fwd_combine.py b/flash_attn/cute/flash_fwd_combine.py new file mode 100644 index 0000000000..4c423b8096 --- /dev/null +++ b/flash_attn/cute/flash_fwd_combine.py @@ -0,0 +1,644 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h +# from Cutlass C++ to Cute-DSL. +import math +import operator +from typing import Type, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync +from cutlass import Float32, Int32, const_expr + +from flash_attn.cute import utils +from flash_attn.cute.fast_math import FastDivmod +from flash_attn.cute.seqlen_info import SeqlenInfo + + +class FlashAttentionForwardCombine: + def __init__( + self, + dtype: Type[cutlass.Numeric], + dtype_partial: Type[cutlass.Numeric], + head_dim: int, + m_block_size: int = 8, + k_block_size: int = 64, + log_max_splits: int = 4, + num_threads: int = 256, + stages: int = 4, + ): + """ + Forward combine kernel for split attention computation. 
+ + :param dtype: output data type + :param dtype_partial: partial accumulation data type + :param head_dim: head dimension + :param m_block_size: m block size + :param k_block_size: k block size + :param log_max_splits: log2 of maximum splits + :param num_threads: number of threads + :param varlen: whether using variable length sequences + :param stages: number of pipeline stages + """ + self.dtype = dtype + self.dtype_partial = dtype_partial + self.head_dim = head_dim + self.m_block_size = m_block_size + self.k_block_size = k_block_size + self.max_splits = 1 << log_max_splits + self.num_threads = num_threads + self.is_even_k = head_dim % k_block_size == 0 + self.stages = stages + + @staticmethod + def can_implement( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, + log_max_splits, num_threads, + ) -> bool: + """Check if the kernel can be implemented with the given parameters.""" + if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: + return False + if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: + return False + if head_dim % 8 != 0: + return False + if num_threads % 32 != 0: + return False + if m_block_size % 8 != 0: + return False + max_splits = 1 << log_max_splits + if max_splits > 256: + return False + if (m_block_size * max_splits) % num_threads != 0: + return False + return True + + def _setup_attributes(self): + # GMEM copy setup for O partial + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.dtype_partial.width + assert self.k_block_size % async_copy_elems == 0 + + k_block_gmem = ( + 128 if self.k_block_size % 128 == 0 else + (64 if self.k_block_size % 64 == 0 else 32) + ) + gmem_threads_per_row = k_block_gmem // async_copy_elems + assert self.num_threads % gmem_threads_per_row == 0 + + # Async copy atom for O partial load + atom_async_copy_partial = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self.dtype_partial, + num_bits_per_copy=universal_copy_bits, + ) + tOpartial_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + vOpartial_layout = cute.make_layout((1, async_copy_elems)) # 4 vals per load + self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( + atom_async_copy_partial, tOpartial_layout, vOpartial_layout + ) + + # GMEM copy setup for final O (use universal copy for store) + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=async_copy_elems * self.dtype.width, + ) + self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store + ) + + # LSE copy setup with async copy (alignment = 1) + lse_copy_bits = Float32.width # 1 element per copy, width is in bits + m_block_smem = ( + 128 if self.m_block_size % 128 == 0 else + (64 if self.m_block_size % 64 == 0 else + (32 if self.m_block_size % 32 == 0 else + (16 if self.m_block_size % 16 == 0 else 8))) + ) + gmem_threads_per_row_lse = m_block_smem + assert self.num_threads % gmem_threads_per_row_lse == 0 + + # Async copy atom for LSE load + atom_async_copy_lse = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + Float32, + num_bits_per_copy=lse_copy_bits, + ) + tLSE_layout = cute.make_ordered_layout( + (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), + order=(1, 0), + ) + vLSE_layout = cute.make_layout(1) + self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( + 
atom_async_copy_lse, tLSE_layout, vLSE_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory + # /////////////////////////////////////////////////////////////////////////////// + + # Shared memory to register copy for LSE + self.smem_threads_per_col_lse = self.num_threads // m_block_smem + assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size + + s2r_layout_atom_lse = cute.make_ordered_layout( + (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), + order=(0, 1), + ) + self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), + s2r_layout_atom_lse, + cute.make_layout(1), + ) + + # LSE shared memory layout with swizzling to avoid bank conflicts + # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts + if const_expr(m_block_smem == 8): + smem_lse_swizzle = cute.make_swizzle(5, 0, 5) + elif const_expr(m_block_smem == 16): + smem_lse_swizzle = cute.make_swizzle(4, 0, 4) + else: + smem_lse_swizzle = cute.make_swizzle(3, 2, 3) + smem_layout_atom_lse = cute.make_composed_layout( + smem_lse_swizzle, + 0, + cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) + ) + self.smem_layout_lse = cute.tile_to_shape( + smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1) + ) + + # O partial shared memory layout (simple layout for pipeline stages) + self.smem_layout_o = cute.make_ordered_layout( + (self.m_block_size, self.k_block_size, self.stages), + order=(1, 0, 2) + ) + + + @cute.jit + def __call__( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor] = None, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + num_splits_dynamic_ptr: Optional[cute.Tensor] = None, + semaphore_to_reset: Optional[cute.Tensor] = None, + stream: cuda.CUstream = None, + ): + # Type checking + if const_expr(not (mO_partial.element_type == self.dtype_partial)): + raise TypeError("O partial tensor must match dtype_partial") + if const_expr(not (mO.element_type == self.dtype)): + raise TypeError("O tensor must match dtype") + if const_expr(not mLSE_partial.element_type in [Float32]): + raise TypeError("LSE partial tensor must be Float32") + if const_expr(mLSE is not None and not mLSE.element_type in [Float32]): + raise TypeError("LSE tensor must be Float32") + + # Shape validation - input tensors are in user format, need to be converted to kernel format + if const_expr(len(mO_partial.shape) not in [4, 5]): + raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)") + if const_expr(len(mLSE_partial.shape) not in [3, 4]): + raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)") + if const_expr(len(mO.shape) not in [3, 4]): + raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)") + if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): + raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)") + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mO_partial, mO = [cute.make_tensor(t.iterator, 
cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)] + # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) + # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) + O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] + # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) + mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)) + O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] + mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose)) + # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) + # or (num_splits, total_q, h) -> (total_q, num_splits, h) + LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] + mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose)) + # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) + LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None + + # Determine if we have variable length sequences + varlen = const_expr(cu_seqlens is not None or seqused is not None) + + self._setup_attributes() + + @cute.struct + class SharedStorage: + sLSE: cute.struct.Align[ + cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 + ] + sMaxValidSplit: cute.struct.Align[ + cute.struct.MemRange[Int32, self.m_block_size], 128 + ] + sO: cute.struct.Align[ + cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 + ] + + smem_size = SharedStorage.size_in_bytes() + + # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch) + seqlen = mO_partial.shape[0] + num_head = mO_partial.shape[3] + batch_size = mO_partial.shape[4] + + # Create FastDivmod objects for efficient division + seqlen_divmod = FastDivmod.create(seqlen) + head_divmod = FastDivmod.create(num_head) + + grid_dim = ( + cute.ceil_div(seqlen * num_head, self.m_block_size), + cute.ceil_div(self.head_dim, self.k_block_size), + batch_size, + ) + + self.kernel( + mO_partial, + mLSE_partial, + mO, + mLSE, + cu_seqlens, + seqused, + num_splits_dynamic_ptr, + semaphore_to_reset, + SharedStorage, + self.smem_layout_lse, + self.smem_layout_o, + self.gmem_tiled_copy_O_partial, + self.gmem_tiled_copy_O, + self.gmem_tiled_copy_LSE, + self.s2r_tiled_copy_LSE, + seqlen_divmod, + head_divmod, + varlen, + ).launch( + grid=grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mO_partial: cute.Tensor, + mLSE_partial: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + cu_seqlens: Optional[cute.Tensor], + seqused: Optional[cute.Tensor], + num_splits_dynamic_ptr: Optional[cute.Tensor], + semaphore_to_reset: Optional[cute.Tensor], + SharedStorage: cutlass.Constexpr, + smem_layout_lse: cute.Layout | cute.ComposedLayout, + smem_layout_o: cute.Layout, + gmem_tiled_copy_O_partial: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + gmem_tiled_copy_LSE: cute.TiledCopy, + s2r_tiled_copy_LSE: cute.TiledCopy, + seqlen_divmod: FastDivmod, + head_divmod: FastDivmod, + varlen: cutlass.Constexpr[bool], + ): + # Thread and block indices + tidx, _, _ = cute.arch.thread_idx() + m_block, k_block, 
batch_idx = cute.arch.block_idx() + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sLSE = storage.sLSE.get_tensor(smem_layout_lse) + sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.m_block_size,)) + sO = storage.sO.get_tensor(smem_layout_o) + + # Handle semaphore reset + if const_expr(semaphore_to_reset is not None): + if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and + k_block == cute.arch.grid_dim()[1] - 1 and + batch_idx == cute.arch.grid_dim()[2] - 1): + semaphore_to_reset[0] = 0 + + # Get number of splits + num_splits = ( + num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None) + else mLSE_partial.shape[1] + ) + # Handle variable length sequences using SeqlenInfo + seqlen_info = SeqlenInfo( + batch_idx=batch_idx, + seqlen_static=mO_partial.shape[0], + cu_seqlens=cu_seqlens, + seqused=seqused + ) + seqlen, offset = seqlen_info.seqlen, seqlen_info.offset + + # Extract number of heads (head index will be determined dynamically) + num_head = mO_partial.shape[3] + max_idx = seqlen * num_head + + # Early exit for single split if dynamic + if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx): + + # =============================== + # Step 1: Load LSE_partial from gmem to shared memory + # =============================== + + if const_expr(cu_seqlens is None): + # mLSE_partial_cur = mLSE_partial[None, None, None, batch_idx] + mLSE_partial_cur = utils.coord_offset_i64(mLSE_partial, batch_idx, dim=3) + else: + # mLSE_partial_cur = cute.domain_offset((offset, 0, 0), mLSE_partial) + mLSE_partial_cur = utils.domain_offset_i64((offset, 0, 0), mLSE_partial) + mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) + + gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) + tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) + + # Create identity tensor for coordinate tracking + cLSE = cute.make_identity_tensor((self.max_splits, self.m_block_size)) + tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) + + # Load LSE partial values + for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): + mi = tLSEcLSE[0, 0, m][1] # Get m coordinate + idx = m_block * self.m_block_size + mi + if idx < max_idx: + # Calculate actual sequence position and head using FastDivmod + if const_expr(not varlen): + head_idx, m_idx = seqlen_divmod.divmod(idx) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] + for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): + si = tLSEcLSE[0, s, 0][0] # Get split coordinate + if si < num_splits: + cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m]) + else: + tLSEsLSE[None, s, m].fill(-Float32.inf) + # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem + cute.arch.cp_async_commit_group() + + # =============================== + # Step 2: Load O_partial for pipeline stages + # =============================== + + gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.k_block_size)) + tOcO = gmem_thr_copy_O_partial.partition_D(cO) + tOsO_partial = 
gmem_thr_copy_O_partial.partition_D(sO) + if const_expr(cu_seqlens is None): + # mO_partial_cur = mO_partial[None, None, None, None, batch_idx] + mO_partial_cur = utils.coord_offset_i64(mO_partial, batch_idx, dim=4) + else: + # mO_partial_cur = cute.domain_offset((offset, 0, 0, 0), mO_partial) + mO_partial_cur = utils.domain_offset_i64((offset, 0, 0, 0), mO_partial) + + # Precompute these values to avoid recomputing them in the loop + num_rows = const_expr(cute.size(tOcO, mode=[1])) + tOmidx = cute.make_fragment(num_rows, cutlass.Int32) + tOhidx = cute.make_fragment(num_rows, cutlass.Int32) + tOrOptr = cute.make_fragment(num_rows, cutlass.Int64) + for m in cutlass.range(num_rows, unroll_full=True): + mi = tOcO[0, m, 0][0] # m coordinate + idx = m_block * self.m_block_size + mi + if const_expr(not varlen): + tOhidx[m], tOmidx[m] = seqlen_divmod.divmod(idx) + else: + tOhidx[m] = idx // seqlen + tOmidx[m] = idx - tOhidx[m] * seqlen + tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint() + if idx >= max_idx: + tOhidx[m] = -1 + + tOpO = cute.make_fragment(cute.size(tOcO, [2]), cutlass.Boolean) + if const_expr(not self.is_even_k): + for k in cutlass.range(cute.size(tOpO), unroll_full=True): + tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size + # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) + + load_O_partial = partial( + self.load_O_partial, + gmem_tiled_copy_O_partial, + tOrOptr, + tOsO_partial, + tOhidx, + tOpO, + tOcO, + mO_partial_cur.layout, + ) + + # Load first few stages of O_partial + for stage in cutlass.range(self.stages - 1, unroll_full=True): + if stage < num_splits: + load_O_partial(stage, stage) + cute.arch.cp_async_commit_group() + + # =============================== + # Step 3: Load and transpose LSE from smem to registers + # =============================== + + # Wait for LSE and initial O partial stages to complete + cute.arch.cp_async_wait_group(self.stages - 1) + cute.arch.sync_threads() + # if cute.arch.thread_idx()[0] == 0: + # # cute.print_tensor(sLSE) + # for i in range(64): + # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) + # cute.arch.sync_threads() + + s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) + ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) + ts2rrLSE = cute.make_fragment_like(ts2rsLSE) + cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) + + # =============================== + # Step 4: Compute final LSE along split dimension + # =============================== + + lse_sum = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Float32) + ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) + # We compute the max valid split for each row to short-circuit the computation later + max_valid_split = cute.make_fragment(cute.size(ts2rrLSE, mode=[2]), Int32) + assert cute.size(ts2rrLSE, mode=[0]) == 1 + # Compute max, scales, and final LSE for each row + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + # Find max LSE value across splits + threads_per_col = const_expr(self.smem_threads_per_col_lse) + lse_max = utils.warp_reduce( + ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + op=cute.arch.fmax, + width=threads_per_col, + ) + # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) + # Find max valid split index + max_valid_idx = -1 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + if ts2rrLSE[0, s, m] != -Float32.inf: + max_valid_idx = 
ts2rcLSE[0, s, 0][0] # Get split coordinate + # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) + max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col) + # Compute exp scales and sum + lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf + LOG2_E = math.log2(math.e) + lse_sum_cur = 0.0 + for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): + scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E)) + lse_sum_cur += scale + ts2rrLSE[0, s, m] = scale # Store scale for later use + lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col) + lse_sum[m] = utils.logf(lse_sum_cur) + lse_max + # Normalize scales + inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur + ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) + # Store the scales exp(lse - lse_logsum) back to smem + cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) + + # Store max valid split to smem + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + if mi < self.m_block_size: + sMaxValidSplit[mi] = max_valid_split[m] + + # =============================== + # Step 5: Store final LSE to gmem + # =============================== + + if const_expr(mLSE is not None): + if const_expr(cu_seqlens is None): + # mLSE_cur = mLSE[None, None, batch_idx] + mLSE_cur = utils.coord_offset_i64(mLSE, batch_idx, dim=2) + else: + # mLSE_cur = cute.domain_offset((offset, 0), mLSE) + mLSE_cur = utils.domain_offset_i64((offset, 0), mLSE) + if k_block == 0: # Only first k_block writes LSE when mLSE is provided + for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): + if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes + mi = ts2rcLSE[0, 0, m][1] + idx = m_block * self.m_block_size + mi + if idx < max_idx: + if const_expr(not varlen): + head_idx, m_idx = seqlen_divmod.divmod(idx) + else: + head_idx = idx // seqlen + m_idx = idx - head_idx * seqlen + mLSE_cur[m_idx, head_idx] = lse_sum[m] + + # =============================== + # Step 6: Read O_partial and accumulate final O + # =============================== + + cute.arch.sync_threads() + + # Get max valid split for this thread + thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] + for m in cutlass.range(1, cute.size(tOcO, mode=[1])): + thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) + + tOrO_partial = cute.make_fragment_like(tOsO_partial[None, None, None, 0]) + tOrO = cute.make_fragment_like(tOrO_partial, Float32) + tOrO.fill(0.0) + + stage_load = self.stages - 1 + stage_compute = 0 + + # Main accumulation loop + for s in cutlass.range(thr_max_valid_split + 1, unroll=4): + # Get scales for this split + scale = cute.make_fragment(num_rows, Float32) + for m in cutlass.range(num_rows, unroll_full=True): + scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem + + # Load next stage if needed + split_to_load = s + self.stages - 1 + if split_to_load <= thr_max_valid_split: + load_O_partial(split_to_load, stage_load) + cute.arch.cp_async_commit_group() + stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 + + # Wait for the current stage to be ready + cute.arch.cp_async_wait_group(self.stages - 1) + # We don't need __syncthreads() because each thread is just reading its own data from smem + # Copy from smem 
to registers + cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) + stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 + + # Accumulate scaled partial results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0 and scale[m] > 0.0: + tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32)) + + # =============================== + # Step 7: Write final O to gmem + # =============================== + + rO = cute.make_fragment_like(tOrO, self.dtype) + rO.store(tOrO.load().to(self.dtype)) + if const_expr(cu_seqlens is None): + # mO_cur = mO[None, None, None, batch_idx] + mO_cur = utils.coord_offset_i64(mO, batch_idx, dim=3) + else: + # mO_cur = cute.domain_offset((offset, 0, 0), mO) + mO_cur = utils.domain_offset_i64((offset, 0, 0), mO) + mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) + elems_per_store = const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) + # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + # Write final results + for m in cutlass.range(num_rows, unroll_full=True): + if tOhidx[m] >= 0: + mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_store + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) + + @cute.jit + def load_O_partial( + self, + gmem_tiled_copy_O_partial: cute.TiledCopy, + tOrOptr: cute.Tensor, + tOsO_partial: cute.Tensor, + tOhidx: cute.Tensor, + tOpO: cute.Tensor, + tOcO: cute.Tensor, + mO_cur_partial_layout: cute.Layout, + split: Int32, + stage: Int32, + ) -> None: + elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) + tOsO_partial_cur = tOsO_partial[None, None, None, stage] + for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): + if tOhidx[m] >= 0: + o_gmem_ptr = cute.make_ptr( + tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 + ) + mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))) + mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) + for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): + k_idx = tOcO[0, 0, k][1] // elems_per_load + if const_expr(self.is_even_k) or tOpO[k]: + cute.copy( + gmem_tiled_copy_O_partial, + # mO_partial_cur_copy[None, k_idx, split], + utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx], + tOsO_partial_cur[None, m, k] + ) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py new file mode 100644 index 0000000000..186b219031 --- /dev/null +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -0,0 +1,1909 @@ +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128, (192, 128). 
+# - varlen +# - sliding window +# Unsupported features that will be added later: +# - split-kv (optimizing for inference) +# - more hdim (192, 256) +# Based on the cutlass example and cute-dsl example: +# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py + +import enum +import math +from typing import Type, Tuple, Callable, Optional +from functools import partial + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.nvgpu import cpasync +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils_basic + +import flash_attn.cute.utils as utils +# import flash_attn.cute.pipeline as pipeline +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.fast_math import FastDivmod +from flash_attn.cute.tile_scheduler import TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ParamsBase + + +# class NamedBarrierFwd(enum.IntEnum): +# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +# WarpSchedulerWG1 = enum.auto() +# WarpSchedulerWG2 = enum.auto() +# WarpSchedulerWG3 = enum.auto() +# PFull = enum.auto() +# PEmpty = enum.auto() + + +class FlashAttentionForwardSm100: + + arch = 100 + + def __init__( + self, + # dtype: Type[cutlass.Numeric], + head_dim: int, + head_dim_v: Optional[int] = None, + qhead_per_kvhead: cutlass.Constexpr[int] = 1, + is_causal: bool = False, + is_local: bool = False, + pack_gqa: bool = False, + m_block_size: int = 128, + n_block_size: int = 128, + is_persistent: bool = True, + ): + # self.dtype = dtype + # padding head_dim to a multiple of 16 as k_block_size + hdim_multiple_of = 16 + self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) + head_dim_v = head_dim_v if head_dim_v is not None else head_dim + self.same_hdim_kv = head_dim == head_dim_v + self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) + self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded + self.check_hdim_oob = head_dim != self.head_dim_padded + self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded + self.m_block_size = m_block_size + self.n_block_size = n_block_size + self.q_stage = 2 + assert self.q_stage in [1, 2] + + # 2 Q tile per CTA + self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_qk = (m_block_size, n_block_size, self.head_dim_padded) + self.mma_tiler_pv = (m_block_size, self.head_dim_v_padded, n_block_size) + self.qk_acc_dtype = Float32 + self.pv_acc_dtype = Float32 + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.is_causal = is_causal + self.is_local = is_local + self.qhead_per_kvhead = qhead_per_kvhead + self.pack_gqa = pack_gqa + if pack_gqa: + assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" + # Does S1 need to wait for S0 to finish + # self.s0_s1_barrier = self.head_dim_padded 
in [64, 96] and (not self.is_causal and not self.is_local) + self.s0_s1_barrier = False + self.overlap_sO_sQ = self.head_dim_padded == 192 and self.head_dim_v_padded >= 64 + if self.overlap_sO_sQ: + assert self.head_dim_padded >= self.head_dim_v_padded # We assume sQ is larger than sO + self.is_persistent = False + + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epilogue_warp_ids = (14,) + self.empty_warp_ids = (15,) + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_cta = cute.arch.WARP_SIZE * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + *self.epilogue_warp_ids, + *self.empty_warp_ids, + ) + ) + + self.tmem_alloc_sync_bar_id = 1 + + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 + self.tmem_o_offset = [self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage)] # e.g., 256, 384 + self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded + assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + self.tmem_s_to_p_offset = self.n_block_size // 2 + self.tmem_p_offset = [self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)] # 0, 128 + + # vec buffer for row_max & row_sum + self.tmem_vec_offset = self.tmem_s_offset + + if self.head_dim_padded < 96: + self.num_regs_softmax = 200 + self.num_regs_correction = 64 + self.num_regs_other = 48 + else: + self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 + # self.num_regs_softmax = 176 + # self.num_regs_correction = 96 + # self.num_regs_correction = 80 + # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 + self.num_regs_correction = 64 + # self.num_regs_other = 32 + # self.num_regs_other = 64 + # self.num_regs_other = 80 + # self.num_regs_other = 48 + # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 + self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.epi_stage = 2 + # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: + # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. + # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is + # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be + # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, + # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64. 
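+ # Added note (editorial sketch of the arithmetic above; assumes the hdim-192/128 case): the
+ # shared stage stride below works out to (stride_sK + stride_sV) // 2 = 128 * 160 elements,
+ # so stages 0 and 2 land at the right addresses automatically, while stage 1 has to be
+ # shifted by uneven_kv_smem_offset = m_block_size * (head_dim_padded - head_dim_v_padded) // 2,
+ # with the sign alternating with the pipeline phase (applied by offset_kv_smem in the mma warp).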
+ self.uneven_kv_smem = self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 + self.uneven_kv_smem_offset = self.m_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 + assert self.uneven_kv_smem_offset % 1024 == 0 + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], + softmax_scale: Float32, + stream: cuda.CUstream, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) + softcap: Float32 | float | None = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + learnable_sink: Optional[cute.Tensor] = None, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. + + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. 
Kernel launch with appropriate parameters + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = mQ.element_type + self.k_dtype = mK.element_type + self.v_dtype = mV.element_type + self.o_dtype = mO.element_type + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1]) + mQ, mK, mV, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mQ, mK, mV, mO)] + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ, mO = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) + for t in (mQ, mO) + ] + # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mK, mV = [ + cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) + for t in (mK, mV) + ] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if const_expr(mLSE is not None) else None + # (s, d, h, b) -> (d, s, h, b) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) + + self.q_major_mode = cutlass.utils.LayoutEnum.from_tensor(mQ).mma_major_mode() + self.k_major_mode = cutlass.utils.LayoutEnum.from_tensor(mK).mma_major_mode() + self.v_major_mode = cutlass.utils.LayoutEnum.from_tensor(mV).mma_major_mode() + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + if const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mQ is not supported") + if const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of mK is not supported") + if const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of mV is not supported") + + # check type consistency + if const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None and mSeqUsedQ is None + # This can be tuned + self.e2e_freq = 16 + if const_expr(self.head_dim_padded > 64 and not self.is_causal and not self.is_local and self.pack_gqa): + self.e2e_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & mK-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.mma_tiler_qk[:2], + ) + tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.mma_tiler_pv[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_qk.thr_id.shape,), + ) + + self.epi_tile = 
self.mma_tiler_pv[:2] + + sQ_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage, + ) + sK_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage, + ) + tP_layout = sm100_utils_basic.make_smem_layout_a( + tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.acc_stage, + ) + sV_layout = sm100_utils_basic.make_smem_layout_b( + tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage, + ) + sO_layout = sm100_utils_basic.make_smem_layout_epi( + self.o_dtype, self.o_layout, self.epi_tile, self.epi_stage, + ) + if const_expr(not self.same_hdim_kv_padded): + # sK and sV are using the same physical smem so we need to adjust the stride so that they line up + stride_sK = const_expr(max(sK_layout.outer.stride[-1], 0)) # take max to turn tuple to Int32 + stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0)) + stage_stride = const_expr(max(stride_sK, stride_sV) if not self.uneven_kv_smem else (stride_sK + stride_sV) // 2) + sK_layout = cute.make_composed_layout(sK_layout.inner, 0, cute.make_layout((*sK_layout.outer.shape[:-1], self.kv_stage), stride=(*sK_layout.outer.stride[:-1], stage_stride))) + sV_layout = cute.make_composed_layout(sV_layout.inner, 0, cute.make_layout((*sV_layout.outer.shape[:-1], self.kv_stage), stride=(*sV_layout.outer.stride[:-1], stage_stride))) + + if const_expr(self.pack_gqa): + shape_Q_packed = ((self.qhead_per_kvhead, mQ.shape[0]), mQ.shape[1], mK.shape[2], *mQ.shape[3:]) + stride_Q_packed = ((mQ.stride[2], mQ.stride[0]), mQ.stride[1], mQ.stride[2] * self.qhead_per_kvhead, *mQ.stride[3:]) + mQ = cute.make_tensor(mQ.iterator, cute.make_layout(shape_Q_packed, stride=stride_Q_packed)) + shape_O_packed = ((self.qhead_per_kvhead, mO.shape[0]), mK.shape[1], mK.shape[2], *mO.shape[3:]) + stride_O_packed = ((mO.stride[2], mO.stride[0]), mO.stride[1], mO.stride[2] * self.qhead_per_kvhead, *mO.stride[3:]) + mO = cute.make_tensor(mO.iterator, cute.make_layout(shape_O_packed, stride=stride_O_packed)) + if const_expr(mLSE is not None): + shape_LSE_packed = ((self.qhead_per_kvhead, mLSE.shape[0]), mK.shape[2], *mLSE.shape[2:]) + stride_LSE_packed = ((mLSE.stride[1], mLSE.stride[0]), mLSE.stride[1] * self.qhead_per_kvhead, *mLSE.stride[2:]) + mLSE = cute.make_tensor(mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)) + + # TMA load for Q + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + mQ, + cute.select(sQ_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mK, + cute.select(sK_layout, mode=[0, 1, 2]), + self.mma_tiler_qk, + tiled_mma_qk, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + mV, + cute.select(sV_layout, mode=[0, 1, 2]), + self.mma_tiler_pv, + tiled_mma_pv, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) + + # print(sO_layout.outer) + if const_expr(not self.use_tma_O): + self.epilogue_warp_ids = (14, 15) + self.empty_warp_ids = () + self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) + if const_expr(self.use_tma_O): + tma_atom_O, mO = 
cpasync.make_tiled_tma_atom( + tma_store_op, + mO, + cute.select(sO_layout, mode=[0, 1]), + o_cta_v_layout, + ) + gmem_tiled_copy_O = None + else: + tma_atom_O = None + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self.o_dtype.width + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self.o_dtype, num_bits_per_copy=universal_copy_bits, + ) + tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems + tO_layout = cute.make_ordered_layout( + (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1), order=(1, 0), + ) + # So that we don't have to check if we overshoot kBlockM when we store O + assert self.m_block_size % tO_layout.shape[0] == 0 + vO_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) + + self.tma_copy_q_bytes = cute.size_in_bytes(self.q_dtype, cute.select(sQ_layout, mode=[0, 1, 2])) + self.tma_copy_k_bytes = cute.size_in_bytes(self.k_dtype, cute.select(sK_layout, mode=[0, 1, 2])) + self.tma_copy_v_bytes = cute.size_in_bytes(self.v_dtype, cute.select(sV_layout, mode=[0, 1, 2])) + + if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): + TileScheduler = SingleTileVarlenScheduler + else: + if const_expr(self.is_causal or self.is_local): + TileScheduler = SingleTileLPTScheduler + else: + TileScheduler = SingleTileScheduler if const_expr(not self.is_persistent) else StaticPersistentTileScheduler + tile_sched_args = TileSchedulerArguments( + cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.size(mQ.shape[2]), + cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), + cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], + mQ.shape[1], + mV.shape[0], # Note that this is different from Sm90 since we transpose mV in Sm100 + total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=self.cta_tiler[:2], + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.k_dtype.width // 8, + is_persistent=self.is_persistent, + lpt=self.is_causal or self.is_local, + ) + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + + self.mbar_load_q_full_offset = 0 + self.mbar_load_q_empty_offset = self.mbar_load_q_full_offset + self.q_stage + self.mbar_load_kv_full_offset = self.mbar_load_q_empty_offset + self.q_stage + self.mbar_load_kv_empty_offset = self.mbar_load_kv_full_offset + self.kv_stage + self.mbar_P_full_O_rescaled_offset = self.mbar_load_kv_empty_offset + self.kv_stage + self.mbar_S_full_offset = self.mbar_P_full_O_rescaled_offset + 2 + self.mbar_O_full_offset = self.mbar_S_full_offset + 2 + self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + 2 + self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + 2 + self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.epi_stage + self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.epi_stage + self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 + self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 + self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 + self.mbar_total = 
self.mbar_P_full_2_offset + 2 + + sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0 + + @cute.struct + class SharedStorage: + # m_barriers for pipelines + mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mbar_total] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + # store row max and row sum + sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, sO_size], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(sQ_layout)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem + cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. + # Right after this, we multiply by log2(e) before applying exp2. + # To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val + # (assigning it to softcap_val) and pre-multiply softcap_val * log2(e) + # (assigning it to softmax_scale_log2). + LOG2_E = math.log2(math.e) + if const_expr(softcap is None): + softmax_scale_log2 = softmax_scale * LOG2_E + softcap_val = None + else: + softmax_scale_log2 = softcap * LOG2_E + softcap_val = Float32(softmax_scale / softcap) + if const_expr(window_size_left is not None): + window_size_left = Int32(window_size_left) + if const_expr(window_size_right is not None): + window_size_right = Int32(window_size_right) + # Launch the kernel synchronously + self.kernel( + tma_tensor_Q, + tma_tensor_K, + tma_tensor_V, + mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + tma_atom_O, + softmax_scale_log2, + softcap_val, + window_size_left, + window_size_right, + learnable_sink, + sQ_layout, + sK_layout, + tP_layout, + sV_layout, + sO_layout, + gmem_tiled_copy_O, + tiled_mma_qk, + tiled_mma_pv, + tile_sched_params, + ).launch( + grid=grid_dim, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q + mK: cute.Tensor, # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table + mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + tma_atom_O: Optional[cute.CopyAtom], + softmax_scale_log2: Float32, + softcap_val: Optional[Float32], + window_size_left: Optional[Int32], + window_size_right: Optional[Int32], + learnable_sink: Optional[cute.Tensor], + sQ_layout: cute.ComposedLayout, + sK_layout: cute.ComposedLayout, + tP_layout: cute.ComposedLayout, + sV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_O: Optional[cute.TiledCopy], + tiled_mma_qk: cute.TiledMma, + 
tiled_mma_pv: cute.TiledMma, + tile_sched_params: ParamsBase, + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V) + if const_expr(tma_atom_O is not None): + cpasync.prefetch_descriptor(tma_atom_O) + + # Alloc + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mbar_ptr = storage.mbar_ptr.data_ptr() + if warp_idx == 1: + # Init "full" barrier with number of producers, "empty" barrier with number of consumers + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_full_offset + i, len([self.load_warp_id])) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_load_q_empty_offset + i, len([self.mma_warp_id])) + if warp_idx == 2: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_empty_offset + i, cute.arch.WARP_SIZE * 4) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_softmax_corr_full_offset + i, cute.arch.WARP_SIZE * 4) + if warp_idx == 3: + if const_expr(self.s0_s1_barrier): + for i in cutlass.range_constexpr(8): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE) + if warp_idx == 4: + for i in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_full_offset + i, cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_corr_epi_empty_offset + i, cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + if warp_idx == 5: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_O_rescaled_offset + i, cute.arch.WARP_SIZE * (len(self.softmax0_warp_ids) + len(self.correction_warp_ids))) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_S_full_offset + i, len([self.mma_warp_id])) + cute.arch.mbarrier_init(mbar_ptr + self.mbar_O_full_offset + i, len([self.mma_warp_id])) + if warp_idx == 6: + for i in cutlass.range_constexpr(2): + cute.arch.mbarrier_init(mbar_ptr + self.mbar_P_full_2_offset + i, cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) + if warp_idx == 7: + cute.arch.mbarrier_init( + mbar_ptr + self.mbar_tmem_dealloc_offset, + cute.arch.WARP_SIZE + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync + pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) 
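+ # Added note (editorial): sK below and sV share the same physical smem buffer (sV is just a
+ # re-swizzled view of sK's pointer), and when overlap_sO_sQ is set, sO aliases sQ's smem
+ # instead of getting its own buffer (sO_size is 0 in that case).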
+ # sQ_pi = storage.sQ.get_tensor(sQ_layout) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) + # sK_pi = storage.sK.get_tensor(sK_layout) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer) + if const_expr(not self.overlap_sO_sQ): + sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner) + else: + sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner), sO_layout.outer) + + sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) + + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + + qk_acc_shape = thr_mma_qk.partition_shape_C((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + tStS_fake = thr_mma_qk.make_fragment_C(qk_acc_shape) + # This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always + # request 512 columns of tmem, so we know that it starts at 0. + tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, + assumed_align=16) + tStS = cute.make_tensor(tmem_ptr, tStS_fake.layout) + + pv_acc_shape = thr_mma_pv.partition_shape_C((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + tOtO = thr_mma_pv.make_fragment_C(pv_acc_shape) + + tStSs = tuple(cute.make_tensor(tStS.iterator + self.tmem_s_offset[stage], tStS.layout) + for stage in range(2)) + tOtOs = tuple(cute.make_tensor(tOtO.iterator + self.tmem_o_offset[stage], tOtO.layout) + for stage in range(self.q_stage)) + + tP = cute.make_tensor(tStS.iterator, tP_layout.outer) + tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0] + + tOrPs = [cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p_offset[stage], + tOrP.layout, + ) for stage in range(2)] + + block_info = BlockInfo( + # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1]) + self.cta_tiler[0], self.cta_tiler[1], self.is_causal, self.is_local, + window_size_left, window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0] if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], + mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, self.m_block_size, self.n_block_size, + window_size_left=window_size_left, window_size_right=window_size_right, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) + + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if const_expr(len(self.empty_warp_ids) > 0): + if warp_idx == self.empty_warp_ids[0]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.load( + thr_mma_qk, + thr_mma_pv, + mQ, + mK, + mV, + sQ, + sK, + sV, + 
mPageTable, + tma_atom_Q, + tma_atom_K, + tma_atom_V, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + # if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + if warp_idx == self.mma_warp_id: + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.sync_warp() + + self.mma( + tiled_mma_qk, + tiled_mma_pv, + sQ, + sK, + sV, + sQ_layout.inner, + sK_layout.inner, + sV_layout.inner, + tStSs, + tOtOs, + tOrPs, + pipeline_kv, + mbar_ptr, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + + # if warp_idx == self.mma_warp_id: + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_tmem_dealloc_offset, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g(mO, sO, gmem_tiled_copy_O, tma_atom_O, mbar_ptr, SeqlenInfoCls, TileSchedulerCls) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.correction_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + softmax_loop = partial( + self.softmax_loop, + softmax_scale_log2=softmax_scale_log2, + thr_mma_qk=thr_mma_qk, + sScale=sScale, + mLSE=mLSE, + learnable_sink=learnable_sink, + mbar_ptr=mbar_ptr, + block_info=block_info, + SeqlenInfoCls=SeqlenInfoCls, + AttentionMaskCls=AttentionMaskCls, + TileSchedulerCls=TileSchedulerCls, + ) + + if const_expr(not self.s0_s1_barrier): + stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1) + softmax_loop( + stage=stage, + tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), tStS.layout)) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + else: + # If there's s0_s1_barrier, it's faster to have 2 WGs having different code + if warp_idx < self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[0], tStS.layout) + softmax_loop(stage=0, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]: + tStSi = cute.make_tensor(tStS.iterator + self.tmem_s_offset[1], tStS.layout) + softmax_loop(stage=1, tStSi=tStSi) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= 
self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + self.correction_loop( + thr_mma_qk, + thr_mma_pv, + tStS, + tOtOs, + sScale, + mO, + mLSE, + sO, + learnable_sink, + tma_atom_O, + mbar_ptr, + softmax_scale_log2, + block_info, + SeqlenInfoCls, + TileSchedulerCls, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) + + return + + @cute.jit + def load( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + mPageTable: Optional[cute.Tensor], + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_V: cute.CopyAtom, + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + + q_producer_phase = Int32(1) + kv_producer_state = cutlass.pipeline.make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, self.kv_stage) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mQ_cur = mQ[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mQ_cur = cute.domain_offset((offset, 0), mQ[None, None, head_idx]) + gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_qk, mode=[0, 2]), (None, 0)) + + head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + if const_expr(mPageTable is None): + if const_expr(not seqlen.has_cu_seqlens_k): + mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)] + else: + mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv]) + mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv]) + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None)) + else: + # Need to keep batch coord None since we'll index into it with page idx + mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)] + gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)) + gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)) + tSgQ = thr_mma_qk.partition_A(gQ) + tSgK = thr_mma_qk.partition_B(gK) + tOgV = thr_mma_pv.partition_B(gV) + tQsQ, tQgQ = cpasync.tma_partition( + tma_atom_Q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ, 0, 3), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_atom_K, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_atom_V, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tOgV, 0, 3), + ) + + load_Q = partial( + self.load_Q, tma_atom_Q, tQgQ, tQsQ, + mbar_ptr + self.mbar_load_q_full_offset, mbar_ptr + self.mbar_load_q_empty_offset, + phase=q_producer_phase, + ) + # We have to use mbarrier directly in the load for KV instead of relying on + # pipeline_kv, because we could have a different number of TMA bytes for K and V + load_K = partial( + self.load_KV, tma_atom_K, tKgK, tKsK,
+ mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="K", + ) + load_V = partial( + self.load_KV, tma_atom_V, tVgV, tVsV, + mbar_ptr + self.mbar_load_kv_full_offset, mbar_ptr + self.mbar_load_kv_empty_offset, + K_or_V="V", + ) + + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + page_idx = mPageTable[batch_idx, n_block_max - 1] if const_expr(mPageTable is not None) else None + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 + kv_producer_state.advance() + if const_expr(self.q_stage == 2): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 + kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = mPageTable[batch_idx, n_block] if const_expr(mPageTable is not None) else None + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + tile_scheduler.prefetch_next_work() + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def mma( + self, + tiled_mma_qk: cute.core.ThrMma, + tiled_mma_pv: cute.core.ThrMma, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + sQ_swizzle: cute.Swizzle, + sK_swizzle: cute.Swizzle, + sV_swizzle: cute.Swizzle, + tStSs: Tuple[cute.Tensor, cute.Tensor], + tOtOs: tuple[cute.Tensor], + tOrPs: Tuple[cute.Tensor, cute.Tensor], + pipeline_kv: cutlass.pipeline.PipelineAsync, + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM + thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM + tSrQ = thr_mma_qk.make_fragment_A(sQ) + tSrK = thr_mma_qk.make_fragment_B(sK) + tOrV = thr_mma_pv.make_fragment_B(sV) + if const_expr(self.q_stage == 2): + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1]) + else: + tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 0]) + + qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op + + gemm_Si = [ + partial( + sm100_utils.gemm_ptx_partial, + qk_mma_op, self.tmem_s_offset[stage], tSrQs[stage], sA=sQ[None, None, None, stage], + sA_swizzle=sQ_swizzle, sB_swizzle=sK_swizzle, zero_init=True + ) + for stage in range(2) + ] + gemm_Pi = [ + partial( + sm100_utils.gemm_ptx_partial, + pv_mma_op, self.tmem_o_offset[stage if self.q_stage == 2 else 0], tOrPs[stage], + sA=None, sA_swizzle=None, sB_swizzle=sV_swizzle + ) + for stage in range(2) + ] + + mma_q_consumer_phase = Int32(0) + mma_kv_consumer_state = cutlass.pipeline.make_pipeline_state( + cutlass.pipeline.PipelineUserType.Consumer, self.kv_stage + ) + P_full_O_rescaled_phase = Int32(0) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + + for stage in cutlass.range_constexpr(self.q_stage): + # GEMM_QK00 (Q0 
* K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # 1. wait for Q0 / Q1 + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase) + # 2. wait for K0 + if const_expr(stage == 0): + pipeline_kv.consumer_wait(mma_kv_consumer_state) + tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] + # We don't need to acquire empty S0 / S1. + # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 + # are empty. For subsequent iterations, the wait happened at the end + # of the while loop. + # 3. gemm + # sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) + sK_cur = sK[None, None, None, mma_kv_consumer_state.index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, mma_kv_consumer_state.index, mma_kv_consumer_state.phase) + gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) + # 4. release S0 / S1 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + mma_q_consumer_phase ^= 1 + # 5. release K0 + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + O_should_accumulate = False + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + mma_kv_release_state = mma_kv_consumer_state.clone() + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected O0/O1_partial and P0 / P1 + # For the first iteration in this work tile, waiting for O0/O1_partial + # means that the correction warps has finished reading tO during + # the last iteration of the previous work tile has finished. + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase= P_full_O_rescaled_phase) + # 4. release accumulated O0_partial / O1_partial + # Don't need to signal O_full to the correction warps anymore since the + # correction warps wait for the softmax warps anyway. By the time the softmax + # warps finished, S_i for the next iteration must have been done, so O_i-1 + # must have been done as well. + # with cute.arch.elect_one(): + # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 5. release V(i-1) + if const_expr(stage == 1): + pipeline_kv.consumer_release(mma_kv_release_state) + mma_kv_release_state.advance() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + if const_expr(stage == 0): + mma_kv_consumer_state.advance() + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + # 2. 
gemm + # Don't need to wait for the softmax warp to have finished reading the previous + # Si, since this gemm is scheduled after the PV gemm, which guaranteed that Si + # has been read and Pi has been written. + # sm100_utils.gemm(tiled_mma_qk, tStS0, tSrQs[0], tSrK[None, None, None, Ki_index], zero_init=True) + sK_cur = sK[None, None, None, Ki_index] + if const_expr(self.uneven_kv_smem): + sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) + gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) + # 3. release S0 + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_S_full_offset + stage) + # End of GEMM_QK0i (Q0 * Ki -> S0) + # 4. release Ki + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + P_full_O_rescaled_phase ^= 1 + O_should_accumulate = True + # End of seqlen_kv loop + + # release Q0 & Q1 + with cute.arch.elect_one(): + for stage in cutlass.range_constexpr(self.q_stage): + tcgen05.commit(mbar_ptr + self.mbar_load_q_empty_offset + stage) + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + pipeline_kv.consumer_wait(mma_kv_consumer_state) + Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase + tOrVi = tOrV[None, None, None, Vi_index] + for stage in cutlass.range_constexpr(2): + # 2. acquire corrected Oi_partial and Pi + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase) + # 3. gemm + # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) + # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage](tCrB=tOrVi, sB=sV_cur, zero_init=not O_should_accumulate, mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, mbar_phase=P_full_O_rescaled_phase) + # 4. release accumulated O0_partial + # We do need O_full here since for the last tile, by the time the softmax warp + # has signaled to the correction warp, the softmax warp has just finished compute + # the row sum of the current tile. It does not guarantee that the 1st tile + # of the next work tile has been computed yet. + with cute.arch.elect_one(): + tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + P_full_O_rescaled_phase ^= 1 + # 5. release Vi_end + pipeline_kv.consumer_release(mma_kv_consumer_state) + mma_kv_consumer_state.advance() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # for both softmax0 and softmax1 warp group + @cute.jit + def softmax_loop( + self, + stage: int | Int32, + softmax_scale_log2: Float32, + thr_mma_qk: cute.core.ThrMma, + tStSi: cute.Tensor, + sScale: cute.Tensor, + mLSE: Optional[cute.Tensor], + learnable_sink: Optional[cute.Tensor], + mbar_ptr: cute.Pointer, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + AttentionMaskCls: Callable, + TileSchedulerCls: Callable, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. 
It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. + """ + tidx = cute.arch.thread_idx()[0] % ( + cute.arch.WARP_SIZE + # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + * (len(self.softmax0_warp_ids) + ) + ) + + cS_base = cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1])) + tScS = thr_mma_qk.partition_C(cS_base) + + tStS_scale_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, 1))) + tStScale = cute.make_tensor(tStSi.iterator, tStS_scale_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tilePlikeFP32 = self.mma_tiler_qk[1] // 32 * self.v_dtype.width + tStP_layout = cute.composition(tStSi.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) + tStP = cute.make_tensor(tStSi.iterator + self.tmem_s_to_p_offset, tStP_layout) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32, + ) + thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi).get_slice(tidx) + tStS_t2r = thr_tmem_load.partition_S(tStSi) + + tmem_store_scale_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32, + ) + thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(tidx) + + tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) + tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScS_vec).shape + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP) + thr_tmem_store = tiled_tmem_store.get_slice(tidx) + tStP_r2t = thr_tmem_store.partition_D(tStP) + + mma_si_consumer_phase = Int32(0) + si_corr_producer_phase = Int32(1) + s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) + + # self.warp_scheduler_barrier_init() + + warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + mbar_s0_s1_sequence_offset = self.mbar_s0_s1_sequence_offset + warp_idx_in_wg + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + mask_fn = partial( + mask.apply_mask_sm100, m_block=m_block * 2 + stage, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local + ) + softmax = SoftmaxSm100(softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0) + softmax.reset() + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + ) + + 
cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + si_corr_producer_phase ^= 1 + + # 1 masking iter + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True)) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step(mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False)) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) + # tSrScale_r2t[0] = softmax.row_sum[0] + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + + # # Write LSE to gmem + # if const_expr(mLSE is not None): + # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] + # scale = ( + # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) + # ) + # LN2 = math.log(2.0) + # lse = ( + # (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2 + # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + # ) + # if const_expr(not seqlen.has_cu_seqlens_q): + # mLSE_cur = mLSE[None, head_idx, batch_idx] + # else: + # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) + # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) + # if tidx < 
seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: + # gLSE[tidx] = lse + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def softmax_step( + self, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + n_block: Int32, + softmax: SoftmaxSm100, + mbar_ptr: cute.Pointer, + mbar_s0_s1_sequence_offset: Int32, + thr_mma_qk: cute.core.ThrMma, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + thr_tmem_store_scale: cute.CopyAtom, + tStS_t2r: cute.Tensor, + tStScale_r2t: cute.Tensor, + tStP_r2t: cute.Tensor, + sScale: cute.Tensor, + stage: int | Int32, + mask_fn: Optional[Callable] = None, + is_first: bool = False, + ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tScS_t2r_shape = thr_tmem_load.partition_D(tScS).shape + + # Wait for Si + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_S_full_offset + stage, mma_si_consumer_phase) + tSrS_t2r = cute.make_fragment(tScS_t2r_shape, self.qk_acc_dtype) + cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block) + row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) + + if const_expr(not is_first): + # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScS_vec).shape, Float32) + # tSrScale_r2t[0] = acc_scale + # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) + # cute.arch.fence_view_async_tmem_store() + thread_idx = thr_tmem_load.thr_idx + sScale[thread_idx + stage * self.m_block_size] = acc_scale + # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) + # Notify correction wg that row_max is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + + # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) + # print(tSrS_t2r) + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + # Sequence barrier wait + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_wait(mbar_ptr + mbar_s0_s1_sequence_offset + stage * 4, s0_s1_sequence_phase) + tSrP_r2t_f32 = cute.make_fragment(thr_tmem_store.partition_S(tScP).shape, 
Float32) + tSrP_r2t = cute.make_tensor( + cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout, + ) + # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) + softmax.apply_exp2_convert(tSrS_t2r, tSrP_r2t, e2e=mask_fn is None and self.head_dim_padded <= 128, + e2e_freq=self.e2e_freq) + # Sequence barrier arrive + if const_expr(self.s0_s1_barrier): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_s0_s1_sequence_offset + (1 - stage) * 4) + # print(tSrP_r2t_f32, tStP_r2t) + # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2]) // 4 * 3, cute.size(tStP_r2t.shape[2])): + cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) + cute.arch.fence_view_async_tmem_store() + # Notify mma warp that the 2nd half of P is ready + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_2_offset + stage) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase) + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + # acc_scale = cute.arch.exp2(acc_scale_) + return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 + + @cute.jit + def correction_loop( + self, + thr_mma_qk: cute.core.ThrMma, + thr_mma_pv: cute.core.ThrMma, + tStS: cute.Tensor, + tOtOs: tuple[cute.Tensor], + sScale: cute.Tensor, + mO: cute.Tensor, + mLSE: cute.Tensor, + sO: cute.Tensor, + learnable_sink: Optional[cute.Tensor], + tma_atom_O: cute.CopyAtom, + mbar_ptr: cute.Pointer, + softmax_scale_log2: Float32, + block_info: BlockInfo, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + tScS = thr_mma_qk.partition_C(cute.make_identity_tensor((self.mma_tiler_qk[0], self.mma_tiler_qk[1]))) + tStS_scale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) + tStScales = tuple(cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStS_scale_layout) + for stage in range(2)) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((self.m_block_size, 1))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype, + ) + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]) + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(tidx) + + tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(2)] + tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScS_vec).shape + + # First iter: no correction is required + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 0) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) + + softmax_corr_consumer_phase = Int32(0) + o_corr_consumer_phase = Int32(0) + corr_epi_producer_phase = Int32(1) + + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + n_block_min, n_block_max = 
block_info.get_n_block_min_max(seqlen, m_block) + + # Ignore first signal from softmax as no correction is required + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 0) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + 1, softmax_corr_consumer_phase) + softmax_corr_consumer_phase ^= 1 + + tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) + for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for stage in cutlass.range_constexpr(2): + # wait for S0 / S1 + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + scale = sScale[tidx + stage * self.m_block_size] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + # should_rescale = True + # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) + # Don't need O_full anymore, since by the time softmax has signaled the correction + # warps, S_i must have been done, so O_i-1 must have been done as well. + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + if should_rescale: + self.correction_rescale( + thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + (1 - stage)) + softmax_corr_consumer_phase ^= 1 + # o_corr_consumer_phase ^= 1 + # End of seqlen_corr_loop_steps + + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + + # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without + # additional sync because the MMA in the top half must have been done. + # Similarly we can write to stage 1 of sO without additional sync. 
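+            # Final per-row normalization below (summary of the code that follows): with a
+            # learnable sink, row_sum gains an extra exp2(sink * log2(e) - row_max * softmax_scale_log2)
+            # term; O is then scaled by 1 / row_sum (1.0 if row_sum is zero or NaN), and when mLSE
+            # is present, LSE = (row_max * softmax_scale_log2 + log2(row_sum)) * ln(2).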
+ stats = [None] * self.q_stage + learnable_sink_val = [None] * self.q_stage + if const_expr(learnable_sink is not None): + if const_expr(not self.pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + learnable_sink_val = [sink_val] * self.q_stage + else: # Each thread might have a different sink value due to different q_head + for stage in cutlass.range_constexpr(self.q_stage): + q_head_idx = ((self.q_stage * m_block + stage) * self.m_block_size + tidx) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead + learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_softmax_corr_full_offset + stage, softmax_corr_consumer_phase) + # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) + # cute.arch.fence_view_async_tmem_load() + # scale = tSrScale_t2r[0] + row_sum = sScale[tidx + stage * self.m_block_size] + if const_expr(mLSE is not None or learnable_sink is not None): + row_max = sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] + else: + row_max = None + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + if const_expr(learnable_sink is not None): + LOG2_E = math.log2(math.e) + row_sum += utils.exp2f(learnable_sink_val[stage] * LOG2_E - row_max * softmax_scale_log2) + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase) + self.correction_epilogue( + thr_mma_pv, tOtOs[stage], tidx, scale, sO[None, None, stage], + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + # Signal for the next work tile that O buffers in tmem are already read, so + # mma warp can write to them + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) + # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) + for stage in cutlass.range_constexpr(self.q_stage): + gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)) + row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] + # if tidx == 0 and stage <= 1: + # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2 + if not acc_O_mn_row_is_zero_or_nan else -Float32.inf + ) + seqlen_q = seqlen.seqlen_q if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead + if tidx < seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: + # This actually just works with PackGQA too + gLSE[tidx] = lse + + o_corr_consumer_phase ^= 1 + softmax_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + # gO_qdhb = cute.local_tile(mO, cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) + # gO = gO_qdhb[None, None, None, head_idx, batch_idx] + # 
tOsO, tOgO = cpasync.tma_partition( + # tma_atom_O, + # 0, + # cute.make_layout(1), + # cute.group_modes(sO, 0, 2), + # cute.group_modes(gO, 0, 2), + # ) + # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 + # stage = warp_idx_in_wg + # if stage < self.q_stage: + # # wait from corr, issue tma store on smem + # # 1. wait for O0 / O1 final + # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) + # # 2. copy O0 / O1 to gmem + # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) + # cute.arch.cp_async_bulk_commit_group() + # # Ensure O0 / O1 buffer is ready to be released + # cute.arch.cp_async_bulk_wait_group(0, read=True) + # cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: Int32, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. Store the rescaled results back to tensor memory + """ + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((self.m_block_size, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i) + tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape + tOtO_r2t = thr_tmem_store.partition_D(tOtO_i) + + frg_count = self.head_dim_v_padded // corr_tile_size + tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype) + for i in cutlass.range_constexpr(frg_count): + tOrO_frg_i = tOrO_frg[None, i] + tTMrO_i_layout = cute.composition(tOrO_frg_i.layout, cute.make_layout(tOrO_frg.shape[0])) + tTMrO_i = cute.make_tensor(tOrO_frg_i.iterator, tTMrO_i_layout) + tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tTMrO_i) + for j in cutlass.range_constexpr(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], 
tTMrO_i[j + 1]), (scale, scale), + ) + tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout) + cute.copy(tiled_tmem_store, tTMrO_i, tOtO_r2t_i) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def correction_epilogue( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + thread_idx: Int32, + scale: Float32, + sO: cute.Tensor, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilogue function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + cO = cute.make_identity_tensor((self.mma_tiler_pv[0], self.mma_tiler_pv[1])) + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size))) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils_basic.get_tmem_load_op( + self.mma_tiler_pv, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0]) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils_basic.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy( + smem_copy_atom, + layout_tv=tiled_tmem_load.layout_dst_tv_tiled, + tiler_mn=tiled_tmem_load.tiler_mn, + ) + + tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tOsO_s2r = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in cutlass.range_constexpr(self.head_dim_v_padded // corr_tile_size): + tOtO_t2r_i = tOtO_t2r[None, 0, 0, i] + tOsO_r2s_i = tOsO_s2r[None, 0, 0, i] + tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype) + cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg) + for j in cutlass.range_constexpr(0, cute.size(tOrO_frg), 2): + tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale), + ) + tSMrO = cute.make_fragment(tOrO_frg.shape, self.o_dtype) + o_vec = tOrO_frg.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tOsO_r2s_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta, + ) + + @cute.jit + def 
epilogue_s2g( + self, + mO: cute.Tensor, + sO: cute.Tensor, + gmem_tiled_copy_O: cute.TiledCopy, + tma_atom_O: Optional[cute.CopyAtom], + mbar_ptr: cute.Pointer, + SeqlenInfoCls: Callable, + TileSchedulerCls: Callable, + ): + epi_consumer_phase = Int32(0) + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + m_block, head_idx, batch_idx = work_tile.tile_idx + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(not seqlen.has_cu_seqlens_q): + mO_cur = mO[None, None, head_idx, batch_idx] + else: + offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + mO_cur = cute.domain_offset((offset, 0), mO[None, None, head_idx]) + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + if const_expr(self.use_tma_O): + tOsO, tOgO = cpasync.tma_partition( + tma_atom_O, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + # 2. copy O0 / O1 to gmem + cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) + cute.arch.cp_async_bulk_commit_group() + for stage in cutlass.range_constexpr(self.q_stage): + # Ensure O0 / O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1 - stage, read=True) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + else: + tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA(self.m_block_size, self.head_dim_v_padded, self.check_hdim_v_oob, self.qhead_per_kvhead) + for stage in cutlass.range_constexpr(self.q_stage): + # wait from corr, issue tma store on smem + # 1. wait for O0 / O1 final + cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, epi_consumer_phase) + # 2. 
copy O0 / O1 to gmem + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO[None, None, None, 0], self.o_dtype) + cute.autovec_copy(tOsO[None, None, None, stage], tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size - tOcO[0][0]: + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] if self.check_hdim_v_oob else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, self.q_stage * m_block + stage, seqlen.seqlen_q) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) + + # Advance to next tile + epi_consumer_phase ^= 1 + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + def load_Q( + self, + tma_atom: cute.CopyAtom, + tQgQ: cute.Tensor, + tQsQ: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + stage: int, + phase: Int32, + ): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, self.tma_copy_q_bytes) + cute.copy( + tma_atom, tQgQ[None, block], tQsQ[None, stage], tma_bar_ptr=mbar_full_ptr + stage + ) + + @cute.jit + def load_KV( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + mbar_full_ptr: cute.Pointer, + mbar_empty_ptr: cute.Pointer, + block: Int32, + producer_state: cutlass.pipeline.PipelineState, + K_or_V: str, + page_idx: Optional[Int32] = None, + ): + assert K_or_V in ("K", "V") + tma_copy_bytes = self.tma_copy_k_bytes if const_expr(K_or_V == "K") else self.tma_copy_v_bytes + stage, phase = producer_state.index, producer_state.phase + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, phase) + if const_expr(K_or_V == "K" and self.uneven_kv_smem): + # Before this round, the smem location was occupied by V, which is smaller than + # K. So we need to wait for the stage after that (stage 1) to be empty as well. + if stage == 0: + cute.arch.mbarrier_wait(mbar_empty_ptr + 1, phase) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(mbar_full_ptr + stage, tma_copy_bytes) + tXsX_cur = tXsX[None, stage] + if const_expr(self.uneven_kv_smem): + # Since this is the producer_state, the phase starts at 1, so we have to invert it + tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1) + # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0 + tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx] + cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=mbar_full_ptr + stage) + + @cute.jit + def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): + if const_expr(self.uneven_kv_smem): + # smem layout is [smem_large, smem_small, smem_large], and the current stride is + # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if + # phase == 0, or left by offset if phase == 1. 
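+            # Concretely: for stage == 1 the adjustment below is +uneven_kv_smem_offset when
+            # phase == 0 and -uneven_kv_smem_offset when phase == 1, encoded as (1 - 2 * phase).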
+ offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase) + return cute.make_tensor(sX.iterator + offset, sX.layout) + else: + return sX + + def make_and_init_load_kv_pipeline(self, load_kv_mbar_ptr): + load_kv_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.load_warp_id]) + ) + load_kv_consumer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])) + return cutlass.pipeline.PipelineTmaUmma.create( + barrier_storage=load_kv_mbar_ptr, + num_stages=self.kv_stage, + producer_group=load_kv_producer_group, + consumer_group=load_kv_consumer_group, + tx_count=self.tma_copy_k_bytes, + ) + + # @cute.jit + # def warp_scheduler_barrier_init(self): + # warp_group_idx = utils.canonical_warp_group_idx(sync=False) + # if warp_group_idx == 0: + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * 128, + # ) + + # def warp_scheduler_barrier_sync(self): + # cute.arch.barrier( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False), + # number_of_threads=2 * 128 + # ) + + # def warp_scheduler_barrier_arrive(self): + # cur_wg = utils.canonical_warp_group_idx(sync=False) + # next_wg = 1 - cur_wg + # cute.arch.barrier_arrive( + # barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128, + # ) diff --git a/flash_attn/cute/hopper_helpers.py b/flash_attn/cute/hopper_helpers.py new file mode 100644 index 0000000000..3a57e43da0 --- /dev/null +++ b/flash_attn/cute/hopper_helpers.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025, Tri Dao. +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import warpgroup + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: cutlass.Constexpr[bool] = False, + wg_wait: cutlass.Constexpr[int] = 0, + # A_in_regs: cutlass.Constexpr[bool] = False, + swap_AB: cutlass.Constexpr[bool] = False, +) -> None: + if cutlass.const_expr(swap_AB): + gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) + else: + warpgroup.fence() + # We make a new mma_atom since we'll be modifying its attribute (accumulate). + # Otherwise the compiler complains "operand #0 does not dominate this use" + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + mma_atom.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + if cutlass.const_expr(wg_wait >= 0): + warpgroup.wait_group(wg_wait) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py new file mode 100644 index 0000000000..b02d1e91be --- /dev/null +++ b/flash_attn/cute/interface.py @@ -0,0 +1,800 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. + +# Supported features: +# - BF16 & FP16 dtype +# - noncausal & causal attention +# - MHA, GQA, MQA +# - hdim 64, 96, 128. +# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) +# - varlen +# - sliding window +# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) + +# Features not supported yet: +# - split (i.e. 
FlashDecoding) +# - tuned block sizes +# - paged KV +# - append KV to existing KV cache +# - FP8 +# - bwd pass optimized for Hopper/Blackwell + +import math +from typing import Optional, Tuple + +import torch + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + +from flash_attn.cute import utils +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80, FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +def _flash_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + learnable_sink: Optional[torch.Tensor] = None, + # m_block_size: int = 128, + # n_block_size: int = 64, + # num_threads: int = 128, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 384, + pack_gqa: Optional[bool] = None, + _compute_capability: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + num_head, head_dim = q.shape[-2:] + if cu_seqlens_q is None: + batch_size, seqlen_q = q.shape[:2] + total_q = batch_size * seqlen_q + else: + batch_size = cu_seqlens_q.shape[0] - 1 + seqlen_q = None + total_q = q.shape[0] + if page_table is not None: + assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" + assert page_table.dtype == torch.int32, "page_table must be int32" + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" + max_num_pages_per_seq = page_table.shape[1] + assert page_table.shape == (batch_size, max_num_pages_per_seq) + num_pages, page_size = k.shape[:2] + seqlen_k = num_pages * page_size + else: + num_pages, page_size = None, None + seqlen_k = k.shape[-3] + num_head_kv = k.shape[-2] + head_dim_v = v.shape[-1] + if cu_seqlens_k is None: + if page_table is None: + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + else: + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) + else: + assert k.shape == (seqlen_k, num_head_kv, head_dim) + assert v.shape == (seqlen_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == (batch_size + 1,), "cu_seqlens_k must have shape (batch_size + 1,)" + if cu_seqlens_q is not None: + assert cu_seqlens_q.shape == (batch_size + 1,), "cu_seqlens_q must have shape (batch_size + 1,)" + assert seqused_q is None or seqused_q.shape == (batch_size,), "seqused_q must have shape (batch_size,)" + 
assert seqused_k is None or seqused_k.shape == (batch_size,), "seqused_k must have shape (batch_size,)" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" + for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: + if t is not None: + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" + assert t.stride(0) == 1, "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" + if learnable_sink is not None: + assert learnable_sink.shape == (num_head,) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" + assert all(t is None or t.is_cuda for t in (q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink)), "inputs must be on CUDA device" + assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" + assert head_dim <= 256, "head_dim must be less than or equal to 256" + alignment = 16 // q.element_size() + assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" + assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + if softcap == 0.0: + softcap = None + qhead_per_kvhead = num_head // num_head_kv + if pack_gqa is None: + pack_gqa = qhead_per_kvhead > 1 + + out_torch_dtype = q.dtype + device = q.device + q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) + out = torch.empty(*q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device) + lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad else None + + dtype = torch2cute_dtype_map[q.dtype] + q_tensor, k_tensor, v_tensor, o_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out) + ] + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) if lse is not None else None + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) + ] + page_table_tensor = from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) if page_table is not None else None + if causal: + window_size_right = 0 + local = window_size_left is not None or window_size_right is not None + if window_size_left is not None or window_size_right is not None: + if window_size_left is None and window_size_right == 0: + causal, local = True, False + else: + causal, local = False, True + compute_capability = torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability + assert compute_capability in [9, 10], "Unsupported compute capability. 
Supported: 9.x, 10.x" + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if compute_capability == 9: # TODO: tune block size according to hdim + if head_dim == head_dim_v == 128 and not causal and not local: + n_block_size = 192 + if compute_capability == 10: + # TODO: fix the varlen case + if pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None): + pack_gqa = False + + compile_key = ( + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, + lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, + page_table is not None, + window_size_left is not None, window_size_right is not None, + learnable_sink is not None, + m_block_size, n_block_size, num_threads, pack_gqa, + compute_capability, + ) + if compile_key not in _flash_attn_fwd.compile_cache: + if compute_capability == 9: + assert page_table is None, "paged KV not supported on SM 9.0" + # fa_fwd = FlashAttentionForwardSm80( + fa_fwd = FlashAttentionForwardSm90( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, + # num_stages=1, + num_stages=2, + num_threads=num_threads, + Q_in_regs=False, + ) + elif compute_capability == 10: + assert page_size in [None, 128], "Only page_size=128 is supported for paged KV on SM 10.0" + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + pack_gqa=pack_gqa, + is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, + ) + else: + raise ValueError(f"Unsupported compute capability: {compute_capability}. 
Supported: 9.x, 10.x") + # TODO: check @can_implement + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, + softcap, window_size_left, window_size_right, learnable_sink_tensor, + ) + _flash_attn_fwd.compile_cache[compile_key]( + q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, + cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, + page_table_tensor, + softcap, window_size_left, window_size_right, learnable_sink_tensor, + ) + return out, lse + + +_flash_attn_fwd.compile_cache = {} + + +def _flash_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + dout: torch.Tensor, + lse: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + softcap: float = 0.0, + m_block_size: int = 64, + n_block_size: int = 128, + num_threads: int = 256, + num_stages_Q: int = 2, + num_stages_dO: int = 2, + SdP_swapAB: bool = False, + dKV_swapAB: bool = False, + dQ_swapAB: bool = False, + AtomLayoutMSdP: int = 2, + AtomLayoutNdKV: int = 2, + AtomLayoutMdQ: int = 2, + V_in_regs: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, out, dout, lse = [maybe_contiguous(t) for t in (q, k, v, out, dout, lse)] + batch_size, seqlen_q, num_head, head_dim = q.shape + _, seqlen_k, num_head_kv, _ = k.shape + _, _, _, head_dim_v = v.shape + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) + assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert lse.shape == (batch_size, num_head, seqlen_q), "lse must have shape (batch_size, num_head, seqlen_q)" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, "inputs must have the same dtype" + assert lse.dtype == torch.float32, "lse must be float32" + assert all(t.is_cuda for t in (q, k, v, out, dout, lse)), "inputs must be on CUDA device" + assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" + assert head_dim <= 256, "head_dim must be less than or equal to 256" + alignment = 16 // q.element_size() + assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" + assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + qhead_per_kvhead = num_head // num_head_kv + + device = q.device + # TODO: check if this is the right rounding + seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size + head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq_accum = torch.empty(batch_size, num_head, seqlen_q_rounded * head_dim_rounded, dtype=torch.float32, device=device) + dpsum = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + lse_log2 = torch.empty(batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device) + if qhead_per_kvhead > 1: + seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size + head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 + dk_accum = 
torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, dtype=torch.float32, device=device) + dv_accum = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, dtype=torch.float32, device=device) + + dtype = torch2cute_dtype_map[q.dtype] + q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + for t in (q, k, v, out, dout, dq, dk, dv) + ] + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) + for t in (dq_accum, dpsum, lse_log2) + ] + if qhead_per_kvhead > 1: + dk_accum_tensor, dv_accum_tensor = [ + from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=2) + for t in (dk_accum, dv_accum) + ] + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. + compile_key_pre = (dtype, head_dim_v, m_block_size, num_threads) + if compile_key_pre not in _flash_attn_bwd.compile_cache_pre: + fa_bwd_pre = FlashAttentionBackwardPreprocess( + dtype, head_dim_v, m_block_size, num_threads=num_threads, + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_pre[compile_key_pre] = cute.compile( + fa_bwd_pre, o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, + dq_accum_tensor, current_stream + ) + _flash_attn_bwd.compile_cache_pre[compile_key_pre]( + o_tensor, do_tensor, dpsum_tensor, lse_tensor, lse_log2_tensor, dq_accum_tensor, current_stream + ) + + # Backward kernel: compute dk, dv, dq_accum. 
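+    # (For reference, the backward pass is organized as three launches: the preprocess kernel
+    # above computes (o * dout).sum(-1), lse * log2(e) and zeroes dq_accum; the kernel below
+    # produces dq_accum in fp32 and dk/dv (fp32 accumulators when qhead_per_kvhead > 1); the
+    # postprocess kernels then convert the fp32 accumulators to the bf16/fp16 outputs.)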
+ compile_key = ( + dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap != 0.0, m_block_size, + n_block_size, num_threads, num_stages_Q, num_stages_dO, SdP_swapAB, dKV_swapAB, dQ_swapAB, + AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs + ) + if compile_key not in _flash_attn_bwd.compile_cache: + fa_bwd_sm80 = FlashAttentionBackwardSm80( + dtype, + head_dim, + head_dim_v, + qhead_per_kvhead, + m_block_size, + n_block_size, + num_stages_Q, + num_stages_dO, + num_threads, + causal, + SdP_swapAB, + dKV_swapAB, + dQ_swapAB, + AtomLayoutMSdP, + AtomLayoutNdKV, + AtomLayoutMdQ, + V_in_regs=V_in_regs, + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache[compile_key] = cute.compile( + fa_bwd_sm80, q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + dq_accum_tensor, + dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, + dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + softmax_scale, current_stream + ) + _flash_attn_bwd.compile_cache[compile_key]( + q_tensor, k_tensor, v_tensor, do_tensor, lse_log2_tensor, dpsum_tensor, + dq_accum_tensor, + dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, + dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + softmax_scale, current_stream + ) + + # Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16 + compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, dq_accum_tensor, dq_tensor, softmax_scale, current_stream + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dq_accum_tensor, dq_tensor, softmax_scale, current_stream + ) + + if qhead_per_kvhead > 1: + # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 + compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, dk_accum_tensor, dk_tensor, softmax_scale, current_stream + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dk_accum_tensor, dk_tensor, softmax_scale, current_stream + ) + compile_key_post = (dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) + if compile_key_post not in _flash_attn_bwd.compile_cache_post: + fa_bwd_post = FlashAttentionBackwardPostprocess( + dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB + ) + # TODO: check @can_implement + _flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile( + fa_bwd_post, dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + ) + _flash_attn_bwd.compile_cache_post[compile_key_post]( + dv_accum_tensor, dv_tensor, cutlass.Float32(1.0), current_stream + ) + + return dq, dk, dv + + +_flash_attn_bwd.compile_cache_pre = {} +_flash_attn_bwd.compile_cache = {} +_flash_attn_bwd.compile_cache_post = {} + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = 
None,
+        causal: bool = False,
+        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+        learnable_sink: Optional[torch.Tensor] = None,
+        softcap: float = 0.0,
+        pack_gqa: Optional[bool] = None,
+    ):
+        out, lse = _flash_attn_fwd(
+            q,
+            k,
+            v,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+            pack_gqa=pack_gqa,
+        )
+        ctx.save_for_backward(q, k, v, out, lse)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        return out, lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, lse = ctx.saved_tensors
+        dq, dk, dv = _flash_attn_bwd(
+            q,
+            k,
+            v,
+            out,
+            dout,
+            lse,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.softcap,
+        )
+        # One gradient per forward input: dq, dk, dv, plus None for each of the 6 remaining
+        # arguments (softmax_scale, causal, window_size, learnable_sink, softcap, pack_gqa).
+        return dq, dk, dv, *((None,) * 6)
+
+
+class FlashAttnVarlenFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        cu_seqlens_k: Optional[torch.Tensor],
+        seqused_q: Optional[torch.Tensor] = None,
+        seqused_k: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        softmax_scale: Optional[float] = None,
+        causal: bool = False,
+        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
+        learnable_sink: Optional[torch.Tensor] = None,
+        softcap: float = 0.0,
+        pack_gqa: Optional[bool] = None,
+    ):
+        out, lse = _flash_attn_fwd(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            seqused_q,
+            seqused_k,
+            page_table=page_table,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
+            learnable_sink=learnable_sink,
+            softcap=softcap,
+            pack_gqa=pack_gqa,
+        )
+        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        return out, lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
+        raise NotImplementedError(
+            "Backward pass for FlashAttention with variable length sequences is not implemented yet."
+ ) + + +def flash_attn_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + pack_gqa: Optional[bool] = None, +): + return FlashAttnFunc.apply( + q, + k, + v, + softmax_scale, + causal, + window_size, + learnable_sink, + softcap, + pack_gqa, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Tuple[Optional[int], Optional[int]] = (None, None), + learnable_sink: Optional[torch.Tensor] = None, + softcap: float = 0.0, + pack_gqa: Optional[bool] = None, +): + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + softmax_scale, + causal, + window_size, + learnable_sink, + softcap, + pack_gqa, + ) + + +def _flash_attn_fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: torch.Tensor, + lse: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + semaphore_to_reset: Optional[torch.Tensor] = None, +) -> None: + """Forward combine kernel for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. + + Args: + out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or + (num_splits, total_q, nheads, headdim) if there's cu_seqlens + lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or + (num_splits, total_q, nheads) if there's cu_seqlens + out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens + lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens. 
+        cu_seqlens: Cumulative sequence lengths for variable length sequences
+        seqused: Used sequence lengths for each batch
+        num_splits_dynamic_ptr: Dynamic number of splits per batch
+        semaphore_to_reset: Semaphore for synchronization
+
+    Returns:
+        None
+    """
+    # Input validation
+    assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions"
+    assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions"
+    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], "out_partial must be fp16, bf16, or fp32"
+    assert lse_partial.dtype == torch.float32, "lse_partial must be fp32"
+    assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
+    assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension"
+    assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension"
+    assert lse_partial.shape == out_partial.shape[:-1]
+
+    # Determine if this is variable length based on dimensions
+    is_varlen = out_partial.dim() == 4
+
+    # Validate output tensor shapes and types
+    assert out.shape == out_partial.shape[1:], "out shape mismatch"
+    if lse is not None:
+        assert lse.shape == lse_partial.shape[1:], "lse shape mismatch"
+        assert lse.dtype == torch.float32, "lse must be fp32"
+
+    # Validate optional tensors
+    for t, name in [(cu_seqlens, "cu_seqlens"), (seqused, "seqused"), (num_splits_dynamic_ptr, "num_splits_dynamic_ptr")]:
+        if t is not None:
+            assert t.dtype == torch.int32, f"{name} must be int32"
+            assert t.is_cuda, f"{name} must be on CUDA device"
+            assert t.is_contiguous(), f"{name} must be contiguous"
+
+    head_dim = out_partial.shape[-1]
+    num_splits = out_partial.shape[0]
+    assert num_splits <= 256
+    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
+    # so that kBlockM is smaller and we have more parallelism.
+    k_block_size = 64 if head_dim <= 64 else 128
+    # We want kBlockM to be as small as possible to maximize parallelism.
+    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
+    m_block_size = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)
+    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)
+    if m_block_size == 8:
+        # If kBlockM == 8 then the minimum number of splits is 32.
+ # TODO: we can deal w this by using 128 threads instead + log_max_splits = max(log_max_splits, 5) + + # Convert to cute tensors (using kernel-formatted tensors) + out_partial_tensor = from_dlpack(out_partial.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=4) + lse_partial_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse_partial.ndim - 2) + out_tensor = from_dlpack(out.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=3) + lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 2) if lse is not None else None + + optional_tensors = [ + from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None + for t in (cu_seqlens, seqused, num_splits_dynamic_ptr, semaphore_to_reset) + ] + cu_seqlens_tensor, seqused_tensor, num_splits_dynamic_tensor, semaphore_tensor = optional_tensors + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Create combine kernel configuration + dtype = torch2cute_dtype_map[out.dtype] + dtype_partial = torch2cute_dtype_map[out_partial.dtype] + + compile_key = ( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, + log_max_splits, + cu_seqlens is not None, seqused is not None, lse is not None, + ) + + if compile_key not in _flash_attn_fwd_combine.compile_cache: + fa_combine = FlashAttentionForwardCombine( + dtype=dtype, + dtype_partial=dtype_partial, + head_dim=head_dim, + m_block_size=m_block_size, + k_block_size=k_block_size, + log_max_splits=log_max_splits, + ) + + # Check if implementation is supported + if not fa_combine.can_implement( + dtype, dtype_partial, head_dim, m_block_size, k_block_size, log_max_splits, num_threads=256 + ): + raise RuntimeError(f"FlashAttention combine kernel cannot be implemented with given parameters") + + _flash_attn_fwd_combine.compile_cache[compile_key] = cute.compile( + fa_combine, + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream + ) + + _flash_attn_fwd_combine.compile_cache[compile_key]( + out_partial_tensor, + lse_partial_tensor, + out_tensor, + lse_tensor, + cu_seqlens_tensor, + seqused_tensor, + num_splits_dynamic_tensor, + semaphore_tensor, + current_stream + ) + + +_flash_attn_fwd_combine.compile_cache = {} + + +def flash_attn_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + return_lse: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Flash Attention combine function for split attention computation. + + Combines partial outputs and log-sum-exp values from multiple splits + of attention computation into final outputs. This is the main user-facing + interface for the combine kernel. + + Args: + out_partial: Partial outputs tensor with shape: + - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input + - (num_splits, total_q, num_heads, head_size) for variable length input + lse_partial: Partial LSE tensor with shape: + - (num_splits, batch_size, seqlen, num_heads) for regular batched input + - (num_splits, total_q, num_heads) for variable length input + out: Optional output tensor. If None, will be created automatically. + out_dtype: Optional output dtype. If None, will use fp16/bf16 based on input. + return_lse: Whether to return the combined LSE tensor. Default is True. 
+ + Returns: + Tuple of (out, lse) where: + - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) + or (total_q, num_heads, head_size) for varlen + - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) + or (total_q, num_heads) for varlen. None if return_lse=False + + Note: + This function expects the input tensors to be in the format produced by + split attention computation, where the first dimension is num_splits. + The permuting from user format to kernel format is now done inside the kernel. + """ + # Input validation + assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" + assert lse_partial.dim() in [3, 4], "lse_partial must have 3 or 4 dimensions" + assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + + # Determine if this is variable length based on dimensions + is_varlen = out_partial.dim() == 4 + + if is_varlen: + # Variable length: (num_splits, total_q, num_heads, head_size) + num_splits, total_q, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, total_q, num_heads), "lse_partial shape mismatch for varlen" + batch_size = 1 # Treat as single batch for varlen + seqlen = total_q + else: + # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) + num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), "lse_partial shape mismatch" + + # Determine output dtype + if out_dtype is None: + out_dtype = out_partial.dtype + + # Create output if not provided + device = out_partial.device + if out is None: + if is_varlen: + out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) + else: + out = torch.empty(batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device) + + # Create lse output only if requested + if return_lse: + if is_varlen: + lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device).transpose(0, 1) + else: + lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device).transpose(1, 2) + else: + lse = None + + _flash_attn_fwd_combine(out_partial, lse_partial, out, lse) + return out, lse diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py new file mode 100644 index 0000000000..28c019db7b --- /dev/null +++ b/flash_attn/cute/mask.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025, Tri Dao. 
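+# Masking helpers applied directly to the attention scores S = Q @ K^T held in registers
+# (apply_mask) or in tmem (apply_mask_sm100). Causal masking is bottom-right aligned: a query
+# at global row index row_idx may only attend to key columns col with
+# col < row_idx + 1 + seqlen_k - seqlen_q; the code below folds the per-block offsets into
+# causal_row_offset so the comparison is done on block-local column indices.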
+ +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute + +import flash_attn.cute.utils as utils + + +@dataclass(frozen=True) +class AttentionMask: + m_block_size: cutlass.Constexpr[int] + n_block_size: cutlass.Constexpr[int] + seqlen_q: cutlass.Int32 + seqlen_k: cutlass.Int32 + window_size_left: Optional[cutlass.Int32] = None + window_size_right: Optional[cutlass.Int32] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA + + @cute.jit + def apply_mask( + self, + acc_S: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + thr_mma: cute.TiledMma, + mask_seqlen: cutlass.Constexpr[bool], + mask_causal: cutlass.Constexpr[bool], + mask_local: cutlass.Constexpr[bool] = False, + ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS)) + # We use t0ScS as these indices are known at compile time. We then must subtract the + # column limit by the thread column offset. + t0ScS_mn = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cS)) + thr_col_offset = tScS_mn[0][1] + seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size - thr_col_offset + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): + if cutlass.const_expr(False): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + oob = t0ScS_mn[0, c][1] >= seqlenk_col_limit + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = -cutlass.Float32.inf if oob else acc_S_mn[r, c] + else: # R2P trick, see apply_mask_sm100 + # Instead of comparing limit to 0, 1, 8, 9, 16, 17, ..., + # we compare a transformed version of limit to 0, 1, 2, 3, 4, 5, ... + # This is so that we can use the R2P instruction. + col_limit_transformed = seqlenk_col_limit // 8 * 2 + min(seqlenk_col_limit % 8, 2) + ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + mask = (1 << col_limit_right_s) - 1 + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf + else: # Causal or local + # If PackGQA, we split the work of compute divmod among threads in the same row + threads_per_row = thr_mma.tv_layout_C.shape[0][0] + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + assert cute.arch.WARP_SIZE % threads_per_row == 0, ( + "threads_per_row must divide WARP_SIZE" + ) + assert cute.size(acc_S_mn.shape[0]) <= threads_per_row + tidx = thr_mma.thr_idx + mma_m_idx = ( + m_block * self.m_block_size + tScS_mn[tidx % threads_per_row, 0][0] + ) // self.qhead_per_kvhead_packgqa + causal_row_offset = ( + 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q - thr_col_offset + ) + if cutlass.const_expr(mask_causal): + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. 
+ if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + if cutlass.const_expr(False): + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + acc_S_mn[r, c] = -cutlass.Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] + else: # R2P trick, see apply_mask_sm100 + col_limit_transformed = col_limit_right // 8 * 2 + min(col_limit_right % 8, 2) + ncol = cutlass.const_expr(cute.size(tScS_mn.shape[1])) + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_transformed - s * 24, 0) + mask = (1 << col_limit_right_s) - 1 + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + in_bound = cutlass.Boolean(mask & (1 << i)) + c = s * 24 + i + acc_S_mn[r, c] = acc_S_mn[r, c] if in_bound else -cutlass.Float32.inf + else: # Local + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if cutlass.const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if cutlass.const_expr(self.window_size_left is not None) + else None + ) + for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): + if cutlass.const_expr(self.qhead_per_kvhead_packgqa == 1): + row_idx = tScS_mn[r, 0][0] + m_block * self.m_block_size + else: + row_idx = utils.shuffle_sync( + mma_m_idx, r % threads_per_row, width=threads_per_row + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 + ) + # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block = {}, r = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", n_block, r, row_idx, causal_row_offset, col_limit_right, col_limit_left) + # traverse column index. + for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): + col_idx = t0ScS_mn[0, c][1] + # only consider the column index, so the row index sets to 0. 
+ if col_idx >= col_limit_right or col_idx < col_limit_left: + acc_S_mn[r, c] = -cutlass.Float32.inf + + @cute.jit + def apply_mask_sm100( + self, + acc_S: cute.Tensor, + m_block: cutlass.Int32, + n_block: cutlass.Int32, + thr_mma: cute.TiledMma, + thr_tmem_load: cute.TiledCopy, + mask_seqlen: cutlass.Constexpr, + mask_causal: cutlass.Constexpr, + mask_local: cutlass.Constexpr, + ) -> None: + assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" + cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size)) + tScS = thr_mma.partition_C(cS) + tScS_t2r = thr_tmem_load.partition_D(tScS) + seqlenk_col_limit = self.seqlen_k - n_block * self.n_block_size + if cutlass.const_expr(not mask_causal and not mask_local): + if cutlass.const_expr(mask_seqlen): + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(False): + for i in cutlass.range(ncol, unroll_full=True): + # if tScS_t2r[i][1] >= seqlenk_col_limit: + # acc_S[i] = -cutlass.Float32.inf + # For some reason the 2 lines above generate really bad SASS + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + # Ideally we'd move by 32 instead of 24, but mask >> i isn't correct for i == 31 + # (see below). + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + # Don't need to clamp to 32 since the shr.u32 instruction does that already + col_limit_right_s = max(seqlenk_col_limit - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << col_limit_right_s) - 1 + # if tidx == 0: cute.printf("mask = 0x%x, col_limit_right_s = %d, col_limit_right_s = %d", mask, col_limit_right_s, col_limit_right_s) + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + # mask >> i does not produce correct result for 0b11..11 >> 31 + # However, if we use utils.shr_u32, the compiler doesn't generate + # the R2P instruction, so it's slower. + # Instead we just move by 24 instead of 32. 
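+                            # Worked example (illustrative): if seqlenk_col_limit = 5 and s = 0,
+                            # then mask = 0b11111, so columns 0..4 of this 24-column chunk keep
+                            # their scores and columns 5..23 are set to -inf.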
+ # if tidx == 0: cute.printf("mask_i_bit = %d, after shift = 0x%x, i = %d, s = %d", mask_i_bit, utils.shr_u32(mask, i), i, s) + acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf + # This is the equivalent of: + # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + # if tidx == 0: cute.print_tensor(acc_S) + else: # Causal or local + causal_row_offset = 1 + self.seqlen_k - n_block * self.n_block_size - self.seqlen_q + row_idx = tScS_t2r[0][0] + m_block * self.m_block_size + if cutlass.const_expr(self.qhead_per_kvhead_packgqa != 1): + row_idx = row_idx // self.qhead_per_kvhead_packgqa + if cutlass.const_expr(mask_causal): + col_limit_right = row_idx + causal_row_offset + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + # if cute.arch.thread_idx()[0] % 32 == 0: + # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) + ncol = cutlass.const_expr(cute.size(tScS_t2r.shape)) + if cutlass.const_expr(False): + for i in cutlass.range(ncol, unroll_full=True): + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] + ) + else: + # Bit manipulation, compiles down to the R2P instruction + # We know that tScS_t2r[i][1] == i, for the particular tmem copy atom we're using + for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)): + col_limit_right_s = max(col_limit_right - s * 24, 0) + # 0 -> 0b00...00, 1 -> 0b00...01, ..., 31 -> 0b01...11, 32 -> 0b11...11 + mask = (1 << col_limit_right_s) - 1 + # This needs to be range_constexpr, otherwise the compiler can't generate + # the R2P instruction + for i in cutlass.range_constexpr(min(24, ncol - s * 24)): + acc_S[s * 24 + i] = acc_S[s * 24 + i] if cutlass.Boolean(mask & (1 << i)) else -cutlass.Float32.inf + # This is the equivalent of: + # acc_S[s * 24 + i] = acc_S[s * 24 + i] if col_limit_right_s <= i else -cutlass.Float32.inf + else: + local_row_offset_right = ( + causal_row_offset + self.window_size_right + if cutlass.const_expr(self.window_size_right is not None) + else None + ) + local_row_offset_left = ( + causal_row_offset - 1 - self.window_size_left + if cutlass.const_expr(self.window_size_left is not None) + else None + ) + if cutlass.const_expr(self.window_size_right is not None): + col_limit_right = row_idx + local_row_offset_right + if cutlass.const_expr(mask_seqlen): + col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) + else: + col_limit_right = self.n_block_size + col_limit_left = ( + row_idx + local_row_offset_left if cutlass.const_expr(self.window_size_left is not None) else 0 + ) + # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) + for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): + col_idx = tScS_t2r[i][1] + acc_S[i] = ( + -cutlass.Float32.inf + if col_idx >= col_limit_right or col_idx < col_limit_left + else acc_S[i] + ) diff --git a/flash_attn/cute/mma_sm100_desc.py b/flash_attn/cute/mma_sm100_desc.py new file mode 100644 index 0000000000..62f1bc742e --- /dev/null +++ b/flash_attn/cute/mma_sm100_desc.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025, Tri Dao. 
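+# Descriptor builders for SM100 (Blackwell) tcgen05 MMA: a 32-bit instruction descriptor
+# (make_instr_desc) and a 64-bit shared-memory descriptor (make_smem_desc_base).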
+# Ported Cutlass code from C++ to Python: +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp +# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp + +from enum import IntEnum + +import cutlass +import cutlass.cute as cute + +# --------------------------------------------------------------------------- +# Enumerations that match the HW encodings (values MUST stay identical) +# --------------------------------------------------------------------------- + + +class Major(IntEnum): # matrix “layout” in the ISA docs + K = 0 + MN = 1 + + +class ScaleIn(IntEnum): # negate flags + One = 0 + Neg = 1 + + +class Saturate(IntEnum): + False_ = 0 + True_ = 1 + + +class CFormat(IntEnum): # 2-bit field (bits 4-5) + F16 = 0 + F32 = 1 + S32 = 2 + + +class F16F32Format(IntEnum): # 3-bit field (A/B element type) + F16 = 0 + BF16 = 1 + TF32 = 2 + + +class S8Format(IntEnum): + UINT8 = 0 + INT8 = 1 + + +class MXF8F6F4Format(IntEnum): + E4M3 = 0 + E5M2 = 1 + E2M3 = 3 + E3M2 = 4 + E2M1 = 5 + + +class MaxShift(IntEnum): + NoShift = 0 + MaxShift8 = 1 + MaxShift16 = 2 + MaxShift32 = 3 + + +# --------------------------------------------------------------------------- +# CUTLASS-type → encoding helpers +# --------------------------------------------------------------------------- + + +def to_UMMA_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. + """ + if cutlass_type is cutlass.Int8: + return S8Format.INT8 + # Unsigned 8-bit (if available in your CUTLASS build) + if cutlass_type is cutlass.Uint8: + return S8Format.UINT8 + # FP-16 / BF-16 + if cutlass_type is cutlass.Float16: + return F16F32Format.F16 + if cutlass_type is cutlass.BFloat16: + return F16F32Format.BF16 + # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) + if cutlass_type is cutlass.TFloat32: + return F16F32Format.TF32 + # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them + if cutlass_type is cutlass.FloatE4M3FN: + return MXF8F6F4Format.E4M3 + if cutlass_type is cutlass.FloatE5M2: + return MXF8F6F4Format.E5M2 + raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") + + +def to_C_format(cutlass_type) -> int: + """ + Map a CUTLASS scalar class to the 2-bit accumulator encoding. + """ + if cutlass_type is cutlass.Float16: + return CFormat.F16 + if cutlass_type is cutlass.Float32: + return CFormat.F32 + if cutlass_type is cutlass.Int32: + return CFormat.S32 + raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") + + +# --------------------------------------------------------------------------- +# The constructor – accepts only CUTLASS scalar classes +# --------------------------------------------------------------------------- + + +def make_instr_desc( + a_type, # CUTLASS scalar class, e.g. cutlass.Int8 + b_type, + c_type, + M: int, # 64, 128 or 256 + N: int, # 8 … 256 (multiple of 8) + a_major: Major, + b_major: Major, + a_neg: ScaleIn = ScaleIn.One, + b_neg: ScaleIn = ScaleIn.One, + c_sat: Saturate = Saturate.False_, + is_sparse: bool = False, + max_shift: MaxShift = MaxShift.NoShift, +) -> int: + """ + Build the 32-bit instruction descriptor for Blackwell MMA. + All matrix/accumulator **types must be CUTLASS scalar classes** – + passing integers is forbidden. 
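+
+    Example (illustrative, not from the original sources): a 128x256 MMA with fp16 inputs,
+    fp32 accumulator, and both operands K-major:
+
+        idesc = make_instr_desc(cutlass.Float16, cutlass.Float16, cutlass.Float32,
+                                128, 256, Major.K, Major.K)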
+ """ + # --- encode element formats ------------------------------------------------- + a_fmt = int(to_UMMA_format(a_type)) + b_fmt = int(to_UMMA_format(b_type)) + c_fmt = int(to_C_format(c_type)) + + # --- range checks on M/N ----------------------------------------------------- + if M not in (64, 128, 256): + raise ValueError("M must be 64, 128 or 256") + if N < 8 or N > 256 or (N & 7): + raise ValueError("N must be a multiple of 8 in the range 8…256") + + m_dim = M >> 4 # 5-bit field + n_dim = N >> 3 # 6-bit field + + # --- pack the bit-fields ----------------------------------------------------- + desc = 0 + desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) + desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag + desc |= (int(c_sat) & 0x1) << 3 # saturate + desc |= (c_fmt & 0x3) << 4 # c_format + desc |= (a_fmt & 0x7) << 7 # a_format + desc |= (b_fmt & 0x7) << 10 # b_format + desc |= (int(a_neg) & 0x1) << 13 # a_negate + desc |= (int(b_neg) & 0x1) << 14 # b_negate + desc |= (int(a_major) & 0x1) << 15 # a_major + desc |= (int(b_major) & 0x1) << 16 # b_major + desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) + desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) + desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) + + return desc & 0xFFFF_FFFF # ensure 32-bit result + + +def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): + return make_instr_desc( + op.a_dtype, + op.b_dtype, + op.acc_dtype, + op.shape_mnk[0], + op.shape_mnk[1], + Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, + ) + + +class LayoutType(IntEnum): # occupies the top-3 bits [61:64) + SWIZZLE_NONE = 0 # (a.k.a. “INTERLEAVE” in older docs) + SWIZZLE_128B_BASE32B = 1 + SWIZZLE_128B = 2 + SWIZZLE_64B = 4 + SWIZZLE_32B = 6 + # values 3,5,7 are reserved / illegal for UMMA + + +# --------------------------------------------------------------------------- +# Helpers – figure out the SWIZZLE_* family from the tensor layout +# --------------------------------------------------------------------------- + + +def _layout_type(swizzle: cute.Swizzle) -> LayoutType: + # No idea what the right way to get B, M, S is – so we're just parsing it from the __str__ + # Swizzle string has the form "S" + swz_str = str(swizzle) + inside = swz_str[swz_str.index("<") + 1 : swz_str.index(">")] # '3,4,3' + B, M, S = [int(x) for x in inside.split(",")] # [3, 4, 3] + + if M == 4: # Swizzle<*,4,3> + if S != 3: + raise ValueError("Unexpected swizzle shift – want S==3 for M==4") + return { + 0: LayoutType.SWIZZLE_NONE, + 1: LayoutType.SWIZZLE_32B, + 2: LayoutType.SWIZZLE_64B, + 3: LayoutType.SWIZZLE_128B, + }[B] # KeyError ⇒ invalid B→ raise + if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) + if (B, S) != (2, 2): + raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") + return LayoutType.SWIZZLE_128B_BASE32B + + # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout + raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") + + +def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: + """ + Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit + smem-descriptor, without the smem start address. + layout must correspond to layout of an uint128 tensor. 
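+    The 14-bit start-address field (bits [0:14)) is left as zero here; it is expected to be
+    OR'd in separately from make_smem_desc_start_addr().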
+ """ + # ------------------------------------------------------------------ meta + layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family + + VERSION = 1 # bits 46–47 + LBO_MODE = 0 # bit 52 + BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) + + # ---------------------------------------------------------- strides (units: uint128_t = 16 B) + swizzle_atom_mn_size = { + LayoutType.SWIZZLE_NONE: 1, + LayoutType.SWIZZLE_32B: 2, + LayoutType.SWIZZLE_64B: 4, + LayoutType.SWIZZLE_128B: 8, + LayoutType.SWIZZLE_128B_BASE32B: 8, + }[layout_type] + + if major is Major.MN: + swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 + canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if stride_10 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") + stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] + if layout_type is LayoutType.SWIZZLE_NONE: + stride_byte_offset, leading_byte_offset = stride_01, stride_11 + else: + stride_byte_offset, leading_byte_offset = stride_11, stride_01 + else: + if layout_type == LayoutType.SWIZZLE_128B_BASE32B: + raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") + if not cute.size(layout.shape[0]) % 8 == 0: + raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") + canonical_layout = cute.logical_divide(layout, (8, 2)) + if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): + raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") + stride_00 = canonical_layout.stride[0][0] + if stride_00 != swizzle_atom_mn_size: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_10 = canonical_layout.stride[1][0] + if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: + raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") + stride_01 = canonical_layout.stride[0][1] + stride_byte_offset, leading_byte_offset = stride_01, stride_10 + + # ------------------------------------------------------------------ pack + desc = 0 + # leading_byte_offset_ [16:30) + desc |= (leading_byte_offset & 0x3FFF) << 16 + # stride_byte_offset_ [32:46) + desc |= (stride_byte_offset & 0x3FFF) << 32 + # version_ [46:48) + desc |= (VERSION & 0x3) << 46 + # base_offset_ [49:52) + desc |= (BASE_OFFSET & 0x7) << 49 + # lbo_mode_ [52:53) + desc |= (LBO_MODE & 0x1) << 52 + # layout_type_ [61:64) + desc |= (int(layout_type) & 0x7) << 61 + + return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width + + +def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: + # 14 bits, remove 4 LSB (bits 0-13 in desc) + return (start_addr.toint() & 0x3FFFF) >> 4 diff --git a/flash_attn/cute/named_barrier.py b/flash_attn/cute/named_barrier.py new file mode 100644 index 0000000000..99a76222bc --- /dev/null +++ b/flash_attn/cute/named_barrier.py @@ -0,0 +1,12 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ +import enum + + +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + WarpSchedulerWG1 = enum.auto() + WarpSchedulerWG2 = enum.auto() + WarpSchedulerWG3 = enum.auto() + PFull = enum.auto() + PEmpty = enum.auto() diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py new file mode 100644 index 0000000000..46d8dd3879 --- /dev/null +++ b/flash_attn/cute/pack_gqa.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025, Tri Dao. + +import math +import operator + +import cutlass +import cutlass.cute as cute + +import flash_attn.cute.utils as utils + + +class PackGQA: + def __init__( + self, + m_block_size: cutlass.Constexpr[int], + head_dim_padded: cutlass.Constexpr[int], + check_hdim_oob: cutlass.Constexpr[bool], + qhead_per_kvhead: cutlass.Constexpr[bool], + ): + self.m_block_size = m_block_size + self.head_dim_padded = head_dim_padded + self.check_hdim_oob = check_hdim_oob + self.qhead_per_kvhead = qhead_per_kvhead + + @cute.jit + def compute_ptr( + self, + tensor: cute.Tensor, + cRows: cute.Tensor, + tidx: cutlass.Int32, + block: cutlass.Int32, + threads_per_row: cutlass.Constexpr[int], + num_threads: cutlass.Constexpr[int], + ): + num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) + tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64) + for i in cutlass.range_constexpr(num_ptr_per_thread): + row = i * num_threads + cRows[tidx % threads_per_row][0] + idx = block * self.m_block_size + row + m_idx = idx // self.qhead_per_kvhead + h_idx = idx - m_idx * self.qhead_per_kvhead + tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() + return tPrPtr + + @cute.jit + def load_Q( + self, + mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + sQ: cute.Tensor, # (m_block_size, head_dim_padded) + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tQsQ = gmem_thr_copy.partition_D(sQ) + tQcQ = gmem_thr_copy.partition_S(cQ) + t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) + tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1]) + tQcQ_row = tQcQ[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + q_ptr_i64 = utils.shuffle_sync( + tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + q_gmem_ptr = cute.make_ptr( + mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0QcQ[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] + ): + mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tQsQ.shape[0][0]) + mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): + ki = tQcQ[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + mQ_cur_copy[None, ki], + tQsQ[None, m, k], + pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) + # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + + @cute.jit + def store_LSE( + self, + mLSE: cute.Tensor, # 
(qhead_per_kvhead, seqlen_q) + tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + thr_mma = tiled_mma.get_slice(tidx) + caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = thr_mma.partition_C(caccO) + taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0] + assert cute.size(tLSErLSE) == cute.size(taccOcO_row) + threads_per_row = tiled_mma.tv_layout_C.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + assert cute.size(tLSErLSE) <= threads_per_row + num_threads = tiled_mma.size + tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tLSErLSE)): + lse_ptr_i64 = utils.shuffle_sync( + tPrLSEPtr[m // threads_per_row], + m % threads_per_row, + width=threads_per_row, + ) + lse_gmem_ptr = cute.make_ptr( + mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 + ) + row = block * self.m_block_size + taccOcO_row[m][0] + # Only the thread corresponding to column 0 writes out the lse to gmem + if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: + mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) + mLSE_copy[0] = tLSErLSE[m] + + @cute.jit + def store_O( + self, + mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) + tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy + gmem_tiled_copy: cute.TiledCopy, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, + ): + gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + tOcO = gmem_thr_copy.partition_S(cO) + t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + tOcO_row = tOcO[0, None, 0] + threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] + assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" + num_threads = gmem_tiled_copy.size + tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) + for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + o_ptr_i64 = utils.shuffle_sync( + tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row + ) + o_gmem_ptr = cute.make_ptr( + mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + if ( + t0OcO[0, m, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] + ): + mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) + elems_per_load = cute.size(tOrO.shape[0][0]) + mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) + for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])): + ki = tOcO[0, 0, k][1] // elems_per_load + cute.copy( + gmem_thr_copy, + tOrO[None, m, k], + mO_cur_copy[None, ki], + pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, + ) diff --git a/flash_attn/cute/pipeline.py b/flash_attn/cute/pipeline.py new file mode 100644 index 0000000000..7ea4743c2e --- /dev/null +++ b/flash_attn/cute/pipeline.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025, Tri Dao. 
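+# Pipeline helpers for the forward kernel: PipelineStateSimple packs the circular-buffer
+# index and the phase bit into a single Int32 (recovered with divmod, which reduces to bit
+# twiddling when the stage count is a power of 2), and PipelineTmaAsyncNoCluster is a TMA
+# pipeline in which only one thread per warp group signals consumer_release.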
+ +# import math +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, if_generate +from cutlass.pipeline import PipelineAsync, PipelineState, CooperativeGroup, pipeline_init_wait +from cutlass.pipeline import PipelineUserType, PipelineOp + + +class PipelineStateSimple: + """ + Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. + Use a single Int32 to store both the index and phase bit, then we use divmod to get the + index and phase. If stages is a power of 2, divmod turns into bit twiddling. + """ + + def __init__(self, stages: int, phase_index: Int32): + # assert stages < 2**16 + # self._log_stages = int(math.log2(stages)) + # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2." + self._stages = stages + self._phase_index = phase_index + + def clone(self) -> "PipelineStateSimple": + return PipelineStateSimple(self.stages, self._phase_index) + + @property + def stages(self) -> int: + # return 1 << self._log_stages + return self._stages + + @property + def index(self) -> Int32: + # return self._phase_index & 0xFFFF + # return self._phase_index & ((1 << self._log_stages) - 1) + return self._phase_index % self._stages + + @property + def phase(self) -> Int32: + # return self._phase_index >> 16 + # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to + # take modulo 2. But in practice just passing the phase in without modulo works fine. + # return (self._phase_index >> self._log_stages) % 2 + # return self._phase_index >> self._log_stages + return self._phase_index // self._stages + + def advance(self): + self._phase_index += 1 + + # def then_body(phase_index): + # # XOR the phase bit and set the index to 0 + # return (phase_index & 0xFFFF0000) ^ (1 << 16) + + # def else_body(phase_index): + # return phase_index + + # self._phase_index = if_generate( + # (self._phase_index & 0xFFFF) == self.stages, + # then_body, + # else_body, + # [self._phase_index], + # [Int32], + # ) + + def __extract_mlir_values__(self): + phase_index = self._phase_index + return [phase_index.ir_value()] + + def __new_from_mlir_values__(self, values): + return PipelineStateSimple(self.stages, Int32(values[0])) + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """ + Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. + """ + if type is PipelineUserType.Producer: + # return PipelineStateSimple(stages, Int32(1 << 16)) + return PipelineStateSimple(stages, Int32(stages)) + elif type is PipelineUserType.Consumer: + return PipelineStateSimple(stages, Int32(0)) + else: + assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." + + +@dataclass(frozen=True) +class PipelineTmaAsyncNoCluster(PipelineAsync): + """ + If size(ClusterShape) == 1, PipelineTmaAsync has all threads + signaling the barrier during consumer_release. This causes a perf regression in FA3 + forward pass (especially hdim 128 causal). We instead implement a version of + PipelineTmaAsync where only 1 out of 128 threads signals the barrier. 
+ + Assumptions: + (1) num_consumers % NumThreadsPerWarpGroup == 0 + (2) all 128 threads in the warp group are sync'ed right before calling consumer_release + """ + + @staticmethod + def create( + barrier_storage: cute.Pointer, + num_stages: Int32, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + init_wait: cutlass.Constexpr[bool] = True, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + """ + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + dst_rank = None + producer_mask = None + if cutlass.const_expr(init_wait): + pipeline_init_wait() + return PipelineTmaAsyncNoCluster( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + dst_rank, + ) + + def producer_acquire(self, state: PipelineState, try_acquire_token: Optional[Boolean] = None): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + self.sync_object_full.arrive(state.index, self.producer_mask) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. + """ + pass + + def consumer_release(self, state: PipelineState): + """ + TMA consumer release conditionally signals the empty buffer to the producer. + """ + # Only 1 thread per warp group signals the empty buffer. 
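+        # This relies on assumption (2) above: all 128 threads of the warp group are
+        # synchronized before calling consumer_release, so a single arrival from
+        # thread_idx % 128 == 0 is enough to release the buffer for the whole group.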
+ if_generate( + cute.arch.thread_idx()[0] % 128 == 0, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) diff --git a/flash_attn/cute/pyproject.toml b/flash_attn/cute/pyproject.toml new file mode 100644 index 0000000000..8c4d89e52e --- /dev/null +++ b/flash_attn/cute/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "flash-attn-cute" +version = "0.1.0" +description = "Flash Attention CUTE (CUDA Template Engine) implementation" +readme = "README.md" +requires-python = ">=3.12" +license = {text = "BSD 3-Clause License"} +authors = [ + {name = "Tri Dao"}, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "nvidia-cutlass-dsl==4.1.0", + "torch", + "einops", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", +] + +[project.urls] +Homepage = "https://github.com/Dao-AILab/flash-attention" +Repository = "https://github.com/Dao-AILab/flash-attention" + +[tool.setuptools] +packages = ["flash_attn.cute"] +package-dir = {"flash_attn.cute" = "."} + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +ignore = [ + "E731", # do not assign a lambda expression, use a def + "E741", # Do not use variables named 'I', 'O', or 'l' + "F841", # local variable is assigned to but never used +] \ No newline at end of file diff --git a/flash_attn/cute/seqlen_info.py b/flash_attn/cute/seqlen_info.py new file mode 100644 index 0000000000..dee63db6bf --- /dev/null +++ b/flash_attn/cute/seqlen_info.py @@ -0,0 +1,59 @@ +from typing import Optional + +import cutlass +import cutlass.cute as cute + +""" +This consolidates all the info related to sequence length. This is so that we can do all +the gmem reads once at the beginning of each tile, rather than having to repeat these reads +to compute various things like n_block_min, n_block_max, etc. 
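+
+For example (illustrative), with cu_seqlens = [0, 3, 7] and seqused = None,
+SeqlenInfo(batch_idx=1, ...) resolves to offset = 3 and seqlen = 7 - 3 = 4.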
+""" + +class SeqlenInfo: + def __init__( + self, + batch_idx: cutlass.Int32, + seqlen_static: cutlass.Int32, + cu_seqlens: Optional[cute.Tensor] = None, + seqused: Optional[cute.Tensor] = None, + ): + self.offset = 0 if cutlass.const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] + if cutlass.const_expr(seqused is not None): + self.seqlen = seqused[batch_idx] + elif cutlass.const_expr(cu_seqlens is not None): + self.seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] + else: + self.seqlen = seqlen_static + + +class SeqlenInfoQK: + def __init__( + self, + batch_idx: cutlass.Int32, + seqlen_q_static: cutlass.Int32, + seqlen_k_static: cutlass.Int32, + mCuSeqlensQ: Optional[cute.Tensor] = None, + mCuSeqlensK: Optional[cute.Tensor] = None, + mSeqUsedQ: Optional[cute.Tensor] = None, + mSeqUsedK: Optional[cute.Tensor] = None, + ): + self.offset_q = 0 if cutlass.const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] + self.offset_k = 0 if cutlass.const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] + if cutlass.const_expr(mSeqUsedQ is not None): + self.seqlen_q = mSeqUsedQ[batch_idx] + else: + self.seqlen_q = ( + seqlen_q_static + if cutlass.const_expr(mCuSeqlensQ is None) + else mCuSeqlensQ[batch_idx + 1] - self.offset_q + ) + if cutlass.const_expr(mSeqUsedK is not None): + self.seqlen_k = mSeqUsedK[batch_idx] + else: + self.seqlen_k = ( + seqlen_k_static + if cutlass.const_expr(mCuSeqlensK is None) + else mCuSeqlensK[batch_idx + 1] - self.offset_k + ) + self.has_cu_seqlens_q: int = mCuSeqlensQ is not None + self.has_cu_seqlens_k: int = mCuSeqlensK is not None diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py new file mode 100644 index 0000000000..6d8135d646 --- /dev/null +++ b/flash_attn/cute/softmax.py @@ -0,0 +1,262 @@ +# Copyright (c) 2025, Tri Dao. + +import math +import operator +from typing import Tuple + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 + +import flash_attn.cute.utils as utils + + +class Softmax: + def __init__( + self, + scale_log2: Float32, + num_rows: cutlass.Constexpr[int], + arch: cutlass.Constexpr[int] = 80, + ): + self.scale_log2 = scale_log2 + self.row_max = cute.make_fragment(num_rows, Float32) + self.row_sum = cute.make_fragment_like(self.row_max) + self.arch = arch + + def reset(self) -> None: + self.row_max.fill(-Float32.inf) + self.row_sum.fill(0.0) + + def _compute_row_max( + self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) + + def _compute_row_sum( + self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None + ) -> Float32: + return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) + + @cute.jit + def online_softmax( + self, + acc_S: cute.Tensor, + is_first: cutlass.Constexpr[bool] = False, + check_inf: cutlass.Constexpr[bool] = True, + ) -> cute.Tensor: + """Apply online softmax and return the row_scale to rescale O. + + :param acc_S: acc_S tensor + :type acc_S: cute.Tensor + :param is_first: is first n_block + :type is_first: cutlass.Constexpr + """ + # Change acc_S to M,N layout view. 
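+        # Per-row recurrence implemented by the loop below: with m_new = max(m_old, max_j S_j)
+        # and P_j = exp2(S_j * scale_log2 - m_new * scale_log2), the running sum becomes
+        # l_new = l_old * exp2((m_old - m_new) * scale_log2) + sum_j P_j, and the same factor
+        # row_scale = exp2((m_old - m_new) * scale_log2) is returned to rescale acc_O.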
+ acc_S_mn = utils.make_acc_tensor_mn_view(acc_S) + row_scale = cute.make_fragment_like(self.row_max, Float32) + # Each iteration processes one row of acc_S + for r in cutlass.range(cute.size(self.row_max), unroll_full=True): + acc_S_row = acc_S_mn[r, None].load() # (n_block_size) + row_max_cur = self._compute_row_max( + acc_S_row, + init_val=self.row_max[r] if cutlass.const_expr(not is_first) else None, + ) + row_max_cur = utils.warp_reduce(row_max_cur, cute.arch.fmax, width=4) + if cutlass.const_expr(check_inf): + row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur + if cutlass.const_expr(is_first): + row_max_cur_scaled = row_max_cur * self.scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) + acc_S_row_sum = self._compute_row_sum(acc_S_row_exp) + row_scale[r] = 1.0 + else: + row_max_prev = self.row_max[r] + row_max_cur_scaled = row_max_cur * self.scale_log2 + acc_S_row_exp = utils.exp2f(acc_S_row * self.scale_log2 - row_max_cur_scaled) + # row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled) + row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * self.scale_log2) + acc_S_row_sum = ( + self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[r] * row_scale[r]) + ) + self.row_max[r] = row_max_cur + self.row_sum[r] = acc_S_row_sum + acc_S_mn[r, None].store(acc_S_row_exp) + return row_scale + + @cute.jit + def finalize(self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None) -> cute.Tensor: + """Finalize the online softmax by computing the scale and logsumexp.""" + if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): + assert cute.size(sink_val) == cute.size(self.row_sum) + # quad reduction for row_sum as we didn't do it during each iteration of online softmax + self.row_sum.store(utils.warp_reduce(self.row_sum.load(), operator.add, width=4)) + row_scale = cute.make_fragment_like(self.row_max, Float32) + for r in cutlass.range(cute.size(self.row_sum), unroll_full=True): + if cutlass.const_expr(sink_val is not None): + sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] + LOG2_E = math.log2(math.e) + self.row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - self.row_max[r] * self.scale_log2) + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 + acc_O_mn_row_is_zero_or_nan = ( + self.row_sum[r] == 0.0 or self.row_sum[r] != self.row_sum[r] + ) + row_scale[r] = ( + cute.arch.rcp_approx(self.row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) + ) * final_scale + row_sum_cur = self.row_sum[r] + LN2 = math.log(2.0) + self.row_sum[r] = ( + (self.row_max[r] * self.scale_log2 + utils.log2f(row_sum_cur)) * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + return row_scale + + @cute.jit + def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: + """Scale each row of acc_O by the given scale tensor. 
+ :param acc_O: input tensor + :type acc_O: cute.Tensor + :param row_scale: row_scale tensor + :type row_scale: cute.Tensor + """ + acc_O_mn = utils.make_acc_tensor_mn_view(acc_O) + assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) + for r in cutlass.range(cute.size(row_scale), unroll_full=True): + acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) + + +class SoftmaxSm100(Softmax): + def __init__(self, scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0): + super().__init__(scale_log2, num_rows=1, arch=100) + self.rescale_threshold = rescale_threshold + + @cute.jit + def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = utils.exp2f(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + + def update_row_sum( + self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False + ) -> None: + init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None + # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) + self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) + # tmp = self._compute_row_sum(acc_S_row_exp) + # self.row_sum[0] = self.row_sum[0] * row_scale + tmp + + @cute.jit + def scale_subtract_rowmax( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + row_max_scaled = row_max * self.scale_log2 + for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (-row_max_scaled, -row_max_scaled), + ) + + @cute.jit + def apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + acc_S_row_converted: cute.Tensor, + e2e: cutlass.Constexpr[bool] = False, + e2e_freq: cutlass.Constexpr[int] = 16, + e2e_res: cutlass.Constexpr[int] = 4, + e2e_frg_limit: cutlass.Constexpr[int] = 1, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + if cutlass.const_expr(not e2e): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + if cutlass.const_expr(k % e2e_freq < e2e_freq 
- e2e_res or j >= frg_cnt - e2e_frg_limit): + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + else: + acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) + + @cute.jit + def scale_apply_exp2_convert( + self, + acc_S_row: cute.Tensor, + row_max: Float32, + acc_S_row_converted: cute.Tensor, + ): + assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" + minus_row_max_scaled = -row_max * self.scale_log2 + for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + (acc_S_row[i], acc_S_row[i + 1]), + (self.scale_log2, self.scale_log2), + (minus_row_max_scaled, minus_row_max_scaled), + ) + + # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): + # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( + # (acc_S_row[i], acc_S_row[i + 1]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # acc_S_row[i] = cute.arch.exp2(acc_S_row[i]) + # acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1]) + + frg_tile = 32 + assert frg_tile % 2 == 0 + frg_cnt = cute.size(acc_S_row) // frg_tile + assert cute.size(acc_S_row) % frg_tile == 0 + acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) + acc_S_row_converted_frg = cute.logical_divide( + acc_S_row_converted, cute.make_layout(frg_tile) + ) + for j in cutlass.range_constexpr(frg_cnt): + for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): + # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( + # cute.arch.fma_packed_f32x2( + # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), + # (self.scale_log2, self.scale_log2), + # (minus_row_max_scaled, minus_row_max_scaled), + # ) + # ) + # acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j]) + # acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j]) + acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j]) + acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j]) + acc_S_row_converted_frg[None, j].store( + acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) + ) diff --git a/flash_attn/cute/tile_scheduler.py b/flash_attn/cute/tile_scheduler.py new file mode 100644 index 0000000000..bea4496ecc --- /dev/null +++ b/flash_attn/cute/tile_scheduler.py @@ -0,0 +1,503 @@ +# Copyright (c) 2025, Tri Dao. 
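+# Tile schedulers map (m_block, head, batch) work tiles onto CUDA thread blocks:
+# SingleTileScheduler launches one block per tile, StaticPersistentTileScheduler launches at
+# most one block per SM and strides through tiles, SingleTileLPTScheduler swizzles tiles so
+# heads sharing K/V stay resident in L2 and processes the longest blocks first, and
+# SingleTileVarlenScheduler handles variable-length (cu_seqlens / seqused) batches.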
+ +from typing import Optional, Tuple +from dataclasses import dataclass, fields + +import cutlass +import cutlass.cute as cute +from cutlass import Int32 + +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import FastDivmod, clz + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, cutlass.Constexpr)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, cutlass.Constexpr)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, cutlass.Constexpr) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class TileSchedulerArguments(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + seqlen_k: Int32 + headdim: Int32 + headdim_v: Int32 + total_q: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + element_size: cutlass.Constexpr[int] = 2 + is_persistent: cutlass.Constexpr[bool] = False + lpt: cutlass.Constexpr[bool] = False + + +class SingleTileScheduler: + @dataclass + class Params(ParamsBase): + num_block: Int32 + num_head: Int32 + num_batch: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileScheduler.Params": + return SingleTileScheduler.Params(args.num_block, args.num_head, args.num_batch) + + def __init__(self, blk_coord: cute.Coord, *, loc=None, ip=None): + self._blk_coord = blk_coord + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + blk_coord = cute.arch.block_idx() + return SingleTileScheduler(blk_coord, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return params.num_block, params.num_head, params.num_batch + + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + return cutlass.utils.WorkTileInfo(self._blk_coord, self._is_first_block) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self._blk_coord]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): 
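+        # Rebuild the scheduler from the dynamic (non-constexpr) MLIR values captured by
+        # __extract_mlir_values__ above.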
+ obj_list = [] + for obj, n_items in zip([self._blk_coord], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class StaticPersistentTileScheduler: + @dataclass + class Params(ParamsBase): + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + total_blocks: Int32 + + @staticmethod + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler.Params": + total_blocks = args.num_block * args.num_head * args.num_batch + return StaticPersistentTileScheduler.Params( + FastDivmod.create(args.num_block), FastDivmod.create(args.num_head), total_blocks + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + tile_idx = cute.arch.block_idx()[0] + return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return (cutlass.min(sm_count, params.total_blocks), Int32(1), Int32(1)) + + # @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + hn_idx, block_idx = self.params.num_block_divmod.divmod(self._tile_idx) + batch_idx, head_idx = self.params.num_head_divmod.divmod(hn_idx) + is_valid = self._tile_idx < self.params.total_blocks + # if cute.arch.thread_idx()[0] == 0: + # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) + return cutlass.utils.WorkTileInfo( + (Int32(block_idx), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + self._tile_idx += cute.arch.grid_dim()[0] + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos,): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileLPTScheduler: + @dataclass + class Params(ParamsBase): + total_blocks: Int32 + num_block_divmod: FastDivmod + num_head_divmod: FastDivmod + l2_minor_divmod: FastDivmod + l2_major_divmod: FastDivmod + l2_minor_residual_divmod: FastDivmod + num_hb_quotient: Int32 + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler.Params": + # cute.printf(args.num_block, args.num_head, 
args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size + size_one_head = size_one_kv_head + size_l2 = 50 * 1024 * 1024 # 40 MB for K & V + # Swizzle is the size of each "section". Round swizzle to a power of 2 + # Need to be careful about the case where only one head will fit + # swizzle is how many heads can fit in L2 + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # Seems faster if swizzle if a power of 2 + log2_floor = lambda n: 31 - clz(n) + swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) + # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. + num_hb_quotient = (args.num_head * args.num_batch) // swizzle + num_hb_remainder = (args.num_head * args.num_batch) % swizzle + return SingleTileLPTScheduler.Params( + total_blocks=args.num_block * args.num_head * args.num_batch, + num_block_divmod=FastDivmod.create(args.num_block), + num_head_divmod=FastDivmod.create(args.num_head), + l2_minor_divmod=FastDivmod.create(swizzle), + l2_major_divmod=FastDivmod.create(swizzle * args.num_block), + l2_minor_residual_divmod=FastDivmod.create( + max(num_hb_remainder, 1) + ), # don't divide by 0 + num_hb_quotient=Int32(num_hb_quotient), + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileLPTScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + return (params.total_blocks, Int32(1), Int32(1)) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + # Implement LPT scheduling coordinate calculation + bidhb, l2_mod = params.l2_major_divmod.divmod(self._tile_idx) + # If we're in the last section (called residual), we don't want to divide by + # swizzle. Instead we want to divide by the remainder. 
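+        # Worked example of the mapping below (illustrative numbers, not taken from a
+        # real launch): seqlen_k=8192, headdim=headdim_v=128, fp16 gives
+        # size_one_head = 8192 * 256 * 2 bytes = 4 MiB, so the 50 MiB budget above yields
+        # swizzle = 1 << log2_floor(50 MiB // 4 MiB) = 8, i.e. 8 (head, batch) pairs form
+        # one L2 "section" of 8 * num_block tiles. Consecutive tile indices then sweep
+        # those 8 heads first (l2_mod % 8) and the block index second (l2_mod // 8), so
+        # only ~8 heads' worth of K and V is live in L2 while a section is processed.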
+ block, bidhb_residual = 0, 0 + if bidhb < params.num_hb_quotient: + block, bidhb_residual = params.l2_minor_divmod.divmod(l2_mod) + else: + block, bidhb_residual = params.l2_minor_residual_divmod.divmod(l2_mod) + bidhb_actual = bidhb * params.l2_minor_divmod.divisor + bidhb_residual + batch_idx, head_idx = params.num_head_divmod.divmod(bidhb_actual) + # Longest-processing-time-first + block = params.num_block_divmod.divisor - 1 - block + is_valid = self._tile_idx < params.total_blocks + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._tile_idx = self.params.total_blocks + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileLPTScheduler(*(tuple(obj_list)), loc=self._loc) + + +class SingleTileVarlenScheduler: + @dataclass + class Params(ParamsBase): + num_head: Int32 + num_batch: Int32 + total_q: Int32 + max_kvblock_in_l2: Int32 + tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] + mCuSeqlensQ: Optional[cute.Tensor] = None + mSeqUsedQ: Optional[cute.Tensor] = None + qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 + lpt: cutlass.Constexpr[bool] = False + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler.Params": + size_l2 = 50 * 1024 * 1024 # 50 MB for K & V + max_kvblock_in_l2 = size_l2 // ((args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]) + assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( + "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" + ) + return SingleTileVarlenScheduler.Params( + num_head=args.num_head, + num_batch=args.num_batch, + total_q=args.total_q, + max_kvblock_in_l2=max_kvblock_in_l2, + tile_shape_mn=args.tile_shape_mn, + mCuSeqlensQ=args.mCuSeqlensQ, + mSeqUsedQ=args.mSeqUsedQ, + qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa, + lpt=args.lpt, + ) + + def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): + self.params = params + self._tile_idx = tile_idx + self._is_first_block = True + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + tile_idx = cute.arch.block_idx()[0] + return SingleTileVarlenScheduler(params, tile_idx, loc=loc, ip=ip) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + total_blocks_max = ( + params.total_q + params.num_batch * (params.tile_shape_mn[0] - 1) + ) // params.tile_shape_mn[0] + return (total_blocks_max * params.num_head, 
Int32(1), Int32(1)) + + @cute.jit + def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: + params = self.params + batch_idx = lane + bidb_start + if cutlass.const_expr(params.mSeqUsedQ is not None): + seqlen = Int32(0) + if batch_idx < params.num_batch: + seqlen = params.mSeqUsedQ[batch_idx] + else: + assert params.mCuSeqlensQ is not None + cur_cu_seqlen = Int32(0) + if batch_idx <= params.num_batch: + cur_cu_seqlen = params.mCuSeqlensQ[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1): + seqlen *= params.qhead_per_kvhead_packgqa + return ( + cute.ceil_div(seqlen, params.tile_shape_mn[0]) + if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + lane_idx = cute.arch.lane_idx() + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + # Total number of blocks for the next 31 batches + m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1) + # Same for all lanes + group_end_tile = m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group) + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0) + next_tile_idx = self._tile_idx + while group_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= params.num_batch: + batch_idx = Int32(params.num_batch) + group_end_tile = next_tile_idx + 1 + else: + num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx) + num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx) + m_blocks_in_group = cute.arch.shuffle_sync( + num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1 + ) + group_end_tile += m_blocks_in_group * params.num_head + is_valid = False + if batch_idx >= params.num_batch: + block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch) + else: + group_start_tile = group_end_tile - m_blocks_in_group * params.num_head + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. + batch_idx_in_group = cute.arch.popc( + cute.arch.vote_ballot_sync( + group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_group + num_m_blocks_prev_lane = ( + 0 + if batch_idx_in_group == 0 + else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1) + ) + num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group) + mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head + if cutlass.const_expr(params.lpt): + # This is a version of the SingleTileLPTScheduler, complicated by the fact that + # the seqlen can vary per batch. + # TODO: is there any case where num_m_blocks is 0? 
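+                # Worked example of the head-grouping below (hypothetical numbers): with
+                # num_m_blocks = 8, tile_shape_mn = (128, 128) and no PackGQA, num_n_blocks = 8;
+                # if max_kvblock_in_l2 were 100, then 8 * 16 > 100 but 8 * 8 <= 100, so
+                # nheads_in_l2 = 8 (before being capped at num_head), i.e. 8 heads' K/V
+                # tiles are grouped into one L2 section.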
+ # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here + num_n_blocks = num_m_blocks * params.tile_shape_mn[0] // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] + # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head) + # Seems faster to have this be a power of 2 + nheads_in_l2 = 16 if num_n_blocks * 16 <= params.max_kvblock_in_l2 else (8 if num_n_blocks * 8 <= params.max_kvblock_in_l2 else (4 if num_n_blocks * 4 <= params.max_kvblock_in_l2 else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1))) + nheads_in_l2 = min(nheads_in_l2, params.num_head) + mh_in_l2 = nheads_in_l2 * num_m_blocks + section_idx = mh_block // mh_in_l2 + l2_mod = mh_block - section_idx * mh_in_l2 + # Deal with tail section + nheads_in_this_section = nheads_in_l2 if nheads_in_l2 * (section_idx + 1) <= params.num_head else params.num_head - section_idx * nheads_in_l2 + block = l2_mod // nheads_in_this_section + head_idx_residual = l2_mod - block * nheads_in_this_section + head_idx = section_idx * nheads_in_l2 + head_idx_residual + block = num_m_blocks - 1 - block + else: + head_idx = mh_block // num_m_blocks + block = mh_block - head_idx * num_m_blocks + is_valid = self._is_first_block and batch_idx < params.num_batch + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid) + return cutlass.utils.WorkTileInfo( + (Int32(block), Int32(head_idx), Int32(batch_idx)), is_valid + ) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def prefetch_next_work(self, *, loc=None, ip=None): + pass + + def advance_to_next_work(self, *, loc=None, ip=None): + # Single tile scheduler - set to invalid tile_idx to indicate no more work + self._is_first_block = False + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.params, self._tile_idx]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip([self.params, self._tile_idx], self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py new file mode 100644 index 0000000000..0a26fc9866 --- /dev/null +++ b/flash_attn/cute/utils.py @@ -0,0 +1,591 @@ +# Copyright (c) 2025, Tri Dao. 
+ +import math +from typing import Type, Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, Int32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import nvvm, llvm, arith, vector +from cutlass.cute.runtime import from_dlpack + + +def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor: + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility + ) + ) + + +def make_tiled_copy_A( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if cutlass.const_expr(swapAB): + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + + +def make_tiled_copy_B( + copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.TiledCopy: + if cutlass.const_expr(swapAB): + return cute.make_tiled_copy_A(copy_atom, tiled_mma) + else: + return cute.make_tiled_copy_B(copy_atom, tiled_mma) + + +def mma_make_fragment_A( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if cutlass.const_expr(swapAB): + return mma_make_fragment_B(smem, thr_mma) + else: + return thr_mma.make_fragment_A(thr_mma.partition_A(smem)) + + +def mma_make_fragment_B( + smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False +) -> cute.Tensor: + if cutlass.const_expr(swapAB): + return mma_make_fragment_A(smem, thr_mma) + else: + return thr_mma.make_fragment_B(thr_mma.partition_B(smem)) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric] +) -> cute.CopyAtom: + if cutlass.const_expr(arch < 90): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=2 * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=False, num_matrices=4), + element_type, + ) + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + if cutlass.const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_fragment(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + else: + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). 
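+    For example (Sm80, illustrative shapes): an accumulator layout of shape
+    ((2, 2), 4, 8) is viewed as ((2, 4), (2, 8)), i.e. ((2, MMA_M), (2, MMA_N)).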
+ """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. + # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if cutlass.const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +@dsl_user_op +def exp2f_asm(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "ex2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32: + """exp2f calculation for both vector and scalar. 
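+
+    For example, exp2f(3.0) computes 2**3 and returns (approximately) 8.0;
+    a TensorSSA input is handled elementwise.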
+ :param x: input value + :type x: cute.TensorSSA or Float32 + :return: exp2 value + :rtype: cute.TensorSSA or Float32 + """ + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + res = cute.make_fragment(x.shape, Float32) + res.store(x) + for i in cutlass.range_constexpr(cute.size(x.shape)): + res[i] = cute.arch.exp2(res[i]) + return res.load() + else: + return cute.arch.exp2(x) + + +@dsl_user_op +def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "lg2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + +@dsl_user_op +def logf(a: float | Float32, *, loc=None, ip=None) -> Float32: + return log2f(a, loc=loc, ip=ip) * math.log(2.0) + + +@dsl_user_op +def fmax( + a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None, + loc=loc, + ip=ip, + ) + ) + + +@cute.jit +def fmax_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + # if cutlass.const_expr(init_val is None): + # init_val = -cutlass.Float32.if + # return x.reduce(cute.ReductionOp.MAX, init_val, 0) + res = cute.make_fragment(x.shape, Float32) + res.store(x) + # local_max = [res[0], res[1]] + # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2): + # local_max[0] = fmax(local_max[0], res[i + 0]) + # local_max[1] = fmax(local_max[1], res[i + 1]) + # local_max[0] = fmax(local_max[0], local_max[1]) + # return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) + local_max = [res[0], res[1], res[2], res[3]] + for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + local_max[0] = fmax(local_max[0], res[i + 0]) + local_max[1] = fmax(local_max[1], res[i + 1]) + local_max[2] = fmax(local_max[2], res[i + 2]) + local_max[3] = fmax(local_max[3], res[i + 3]) + local_max[0] = fmax(local_max[0], local_max[1]) + local_max[2] = fmax(local_max[2], local_max[3]) + local_max[0] = fmax(local_max[0], local_max[2]) + return local_max[0] if cutlass.const_expr(init_val is None) else fmax(local_max[0], init_val) + else: + # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max + # We instead force the 3-input max. 
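+        # The unrolled reduction below keeps 4 accumulators (each seeded with a 2- or
+        # 3-input max over the first 8 elements), folds 2 further elements into each
+        # accumulator per iteration via the 3-input max, then combines the 4
+        # accumulators with one 2-input and one 3-input max.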
+ res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_max = [ + fmax(init_val, res[0], res[1]) + if cutlass.const_expr(init_val is not None) + else fmax(res[0], res[1]), + fmax(res[2], res[3]), + fmax(res[4], res[5]), + fmax(res[6], res[7]), + ] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_max[0] = fmax(local_max[0], res[i], res[i + 1]) + local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3]) + local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5]) + local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7]) + local_max[0] = fmax(local_max[0], local_max[1]) + return fmax(local_max[0], local_max[2], local_max[3]) + + +@cute.jit +def fadd_reduce( + x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80 +) -> Float32: + if cutlass.const_expr(arch < 100 or cute.size(x.shape) % 8 != 0): + if cutlass.const_expr(init_val is None): + init_val = Float32.zero + return x.reduce(cute.ReductionOp.ADD, init_val, 0) + # res = cute.make_fragment(x.shape, Float32) + # res.store(x) + # local_sum = [res[0], res[1], res[2], res[3]] + # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4): + # local_sum[0] += res[i + 0] + # local_sum[1] += res[i + 1] + # local_sum[2] += res[i + 2] + # local_sum[3] += res[i + 3] + # local_sum[0] += local_sum[1] + # local_sum[2] += local_sum[3] + # local_sum[0] += local_sum[2] + # return local_sum[0] if cutlass.const_expr(init_val is None) else local_sum[0] + init_val + else: + res = cute.make_fragment(x.shape, Float32) + res.store(x) + local_sum_0 = ( + cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1])) + # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1])) + if cutlass.const_expr(init_val is not None) + else (res[0], res[1]) + ) + local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])] + for i in cutlass.range_constexpr(8, cute.size(x.shape), 8): + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1])) + local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3])) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5])) + local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7])) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1]) + local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3]) + local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2]) + return local_sum[0][0] + local_sum[0][1] + + +@dsl_user_op +def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None: + # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() + # # cache_hint = cutlass.Int64(0x12F0000000000000) + # llvm.inline_asm( + # None, + # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)], + # # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], + # "red.global.add.f32 [$0], $1;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", + # # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", + # "l,f", + # # "l,f,l", + # has_side_effects=True, + # is_align_stack=False, + # asm_dialect=llvm.AsmDialect.AD_ATT, + # ) + nvvm.atomicrmw( + res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value() + ) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@dsl_user_op +def elem_pointer_i64(x: 
cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(x.stride) + assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_fragment( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +@dsl_user_op +def cp_async_mbarrier_arrive_shared( + mbar_ptr: cute.Pointer, noinc: bool = False, *, loc=None, ip=None +) -> None: + nvvm.cp_async_mbarrier_arrive_shared( + mbar_ptr.llvm_ptr, + noinc=noinc, + loc=loc, + ip=ip, + ) + + +def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: + warp_group_idx = cute.arch.thread_idx()[0] // 128 + if cutlass.const_expr(sync): + warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx) + return warp_group_idx + + +# @dsl_user_op +# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: +# mask = cutlass.Int32(-1) +# return cutlass.Boolean( +# llvm.inline_asm( +# T.i32(), +# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], +# ".pred p1, p2;\n" +# "setp.lt.f32 p1, $1, $2;\n" +# "vote.sync.any.pred p2, p1, $3;\n" +# "selp.u32 $0, 1, 0, p2;", +# # "selp.u32 $0, 1, 0, p1;", +# "=r,f,f,r", +# has_side_effects=False, +# is_align_stack=False, +# asm_dialect=llvm.AsmDialect.AD_ATT, +# ) +# ) + + +@cute.jit +def shuffle_sync( + value: cute.Numeric, + offset: cute.typing.Int, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Numeric: + assert value.width % 32 == 0, "value type must be a multiple of 32 bits" + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + val = cute.make_fragment(1, type(value)) + val[0] = value + val_i32 = cute.recast_tensor(val, cutlass.Int32) + for i in cutlass.range_constexpr(cute.size(val_i32)): + val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp) + return val[0] + + +@dsl_user_op +def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32: + return cutlass.Uint32( + llvm.inline_asm( + T.i32(), + [cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)], + "shr.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32: + if cutlass.const_expr(lane is None): + lane = cute.arch.lane_idx() + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 
32, val) + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val) + return val + + +@dsl_user_op +def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32: + assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16" + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)], + f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def cvt_f16(src: cute.Tensor, dst: cute.Tensor): + assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size" + assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements" + assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16" + assert src.element_type is Float32, "src must be Float32" + dst_i32 = cute.recast_tensor(dst, cutlass.Int32) + assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape) + for i in cutlass.range_constexpr(cute.size(dst_i32)): + dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type) + + +@dsl_user_op +def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + out_f32x2 = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()], + "{\n\t" + ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t" + ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t" + ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t" + "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t" + "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t" + "mov.b64 l1, {f1, f2};\n\t" + "mov.f32 f3, 0f4B400000;\n\t" + "mov.b64 l2, {f3, f3};\n\t" + "add.rm.ftz.f32x2 l7, l1, l2;\n\t" + "sub.rn.ftz.f32x2 l8, l7, l2;\n\t" + "sub.rn.ftz.f32x2 l9, l1, l8;\n\t" + "mov.f32 f7, 0f3D9DF09D;\n\t" + "mov.b64 l6, {f7, f7};\n\t" + "mov.f32 f6, 0f3E6906A4;\n\t" + "mov.b64 l5, {f6, f6};\n\t" + "mov.f32 f5, 0f3F31F519;\n\t" + "mov.b64 l4, {f5, f5};\n\t" + "mov.f32 f4, 0f3F800000;\n\t" + "mov.b64 l3, {f4, f4};\n\t" + "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t" + "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t" + "mov.b64 {r1, r2}, l7;\n\t" + "mov.b64 {r3, r4}, l10;\n\t" + "shl.b32 r5, r1, 23;\n\t" + "add.s32 r7, r5, r3;\n\t" + "shl.b32 r6, r2, 23;\n\t" + "add.s32 r8, r6, r4;\n\t" + "mov.b32 $0, r7;\n\t" + "mov.b32 $1, r8;\n\t" + "}\n", + "=r,=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip)) + out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip)) + return out0, out1 +@dsl_user_op +def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + assert isinstance(tensor.iterator, cute.Pointer) + # We assume that applying the offset does not change the pointer alignment 
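+    # The rebuilt pointer below keeps the element type, memory space, and declared
+    # alignment of `tensor`; only the base address is shifted by
+    # crd2idx(coord, tensor.layout) elements, and the original layout is reused.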
+ new_ptr = cute.make_ptr( + tensor.element_type, + elem_pointer(tensor, coord).toint(), + tensor.memspace, + assumed_align=tensor.iterator.alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len( + flat_stride + ), "Coordinate and stride must have the same length" + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@dsl_user_op +def coord_offset_i64( + tensor: cute.Tensor, idx: cute.typing.Int, dim: int, *, loc=None, ip=None +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))) + return cute.make_tensor(new_ptr, new_layout) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 30134990d6..535bd41674 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -38,11 +38,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): return 64 if (not is_dropout and is_causal) else 32 else: return 64 if not is_dropout else 32 - elif head_dim <= 160: - if is_sm8x: - return 64 - else: - return 32 elif head_dim <= 192: return 64 elif head_dim <= 224: @@ -1581,7 +1576,7 @@ def flash_attn_with_kvcache( softmax_scale = q.shape[-1] ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile new file mode 100644 index 0000000000..29a2c0c43e --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/Dockerfile @@ -0,0 +1,17 @@ +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install triton +RUN pip install triton==3.2.0 + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md index 798d78a12d..2d8fd8e70f 100644 --- a/flash_attn/flash_attn_triton_amd/README.md +++ b/flash_attn/flash_attn_triton_amd/README.md @@ -11,39 +11,103 @@ These features are supported in Fwd and Bwd 2) 
Variable sequence lengths 3) Arbitrary Q and KV sequence lengths 4) Arbitrary head sizes +5) Multi and grouped query attention +6) Dropout +7) Rotary embeddings +8) ALiBi -These features are supported in Fwd for now. We will add them to backward soon. -1) Multi and grouped query attention -2) ALiBi and matrix bias - -These features are in development +We are working on the following things 1) Paged Attention 2) Sliding Window -3) Rotary embeddings -4) Dropout -5) Performance Improvements +3) FP8 +4) Performance Improvements -#### Getting Started +##### Getting Started To get started with the triton backend for AMD, follow the steps below. -First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/3ca2f498e98ed7249b82722587c511a5610e00c4). +First install the recommended Triton version ``` -git clone https://github.com/triton-lang/triton -cd triton -git checkout 3ca2f498e98ed7249b82722587c511a5610e00c4 -pip install --verbose -e python +pip install triton==3.2.0 ``` -Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. +Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. ``` -export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" cd flash-attention -python setup.py install -pytest tests/test_flash_attn.py +git checkout main_perf +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install +``` + +To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py +``` + +You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` +``` +FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE +``` + +###### Docker +You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. +``` +FROM rocm/pytorch:latest + +WORKDIR /workspace + +# install triton +RUN pip install triton==3.2.0 + +# install flash attention +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + +RUN git clone https://github.com/ROCm/flash-attention.git &&\ + cd flash-attention &&\ + git checkout main_perf &&\ + python setup.py install + +# set working dir +WORKDIR /workspace/flash-attention ``` -#### Credits +To build the docker file +``` +docker build -t fa_triton . +``` + +To run the docker image +``` +docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton +``` + +###### FP8 +In our fork We have created the following api functions that use fp8 to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. To use these functions just call them with like the other api functions, the casting will be handled internally. For example + +``` +from flash_attn import flash_attn_qkvpacked_fp8_func + +# forward pass +out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + +# backward pass +do = torch.randn_like(out) +dqkv = torch.autograd.grad(out, (qkv), do) +``` + +You can use the other api functions in a similar way. 
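+
+As a concrete sketch (the tensor shapes, dtype, and variable names below are illustrative; the argument order mirrors `flash_attn_varlen_func`), the varlen FP8 variant takes unpadded tensors plus cumulative sequence lengths:
+
+```
+import torch
+from flash_attn import flash_attn_varlen_fp8_func
+
+# toy problem sizes (illustrative only)
+batch, seqlen, nheads, headdim = 2, 128, 4, 64
+total = batch * seqlen
+q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
+k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
+v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
+# cumulative sequence lengths, shape (batch + 1,), int32
+cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")
+
+# forward pass (casting to fp8 is handled internally)
+out, lse, S_dmask = flash_attn_varlen_fp8_func(
+    q,
+    k,
+    v,
+    cu_seqlens,
+    cu_seqlens,
+    seqlen,
+    seqlen,
+    0.0,  # dropout_p
+    causal=True,
+    window_size=(-1, -1),
+    softcap=0.0,
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=True,
+)
+
+# backward pass
+dq, dk, dv = torch.autograd.grad(out, (q, k, v), torch.randn_like(out))
+```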
+ + + +##### Credits AMD Triton kernels team OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py old mode 100644 new mode 100755 index 91939f831f..05e64c349b --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -1,290 +1,1223 @@ -import argparse +import os +import sys import torch import triton -from flash_attn.flash_attn_triton_amd.utils import ( - MetaData, - input_helper, - varlen_input_helper, -) -from flash_attn.flash_attn_triton_amd.interface_torch import attention_prefill, attention_decode - -ARGS_TO_TORCH_DTYPE = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "fp32": torch.float32, +import time +import argparse +import itertools +import pandas as pd +from logging import warning +from typing import Dict, List, Literal, Optional, Tuple +from dataclasses import dataclass +from functools import lru_cache +from utils import get_arch, input_helper + +DEBUG = False + +ENV_FLAGS = ["FLASH_ATTENTION_TRITON_AMD_ENABLE", "FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "FLASH_ATTENTION_TRITON_AMD_DEBUG"] + +FUNCTIONS = [ + "flash_attn_func", + "flash_attn_fp8_func", + "flash_attn_kvpacked_func", + "flash_attn_qkvpacked_func", + "flash_attn_qkvpacked_fp8_func", + "flash_attn_varlen_func", + "flash_attn_varlen_fp8_func", + "flash_attn_varlen_kvpacked_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_qkvpacked_fp8_func", + "flash_attn_with_kvcache", +] + +SUPPORTED_DTYPES = { + "flash_attn_func": [torch.float16], + "flash_attn_fp8_func": [torch.float8_e4m3fnuz], + "flash_attn_kvpacked_func": [torch.float16], + "flash_attn_qkvpacked_func": [torch.float16], + "flash_attn_qkvpacked_fp8_func": [torch.float16], + "flash_attn_varlen_func": [torch.float16], + "flash_attn_varlen_fp8_func": [torch.float8_e4m3fnuz], + "flash_attn_varlen_kvpacked_func": [torch.float16], + "flash_attn_varlen_qkvpacked_func": [torch.float16], + "flash_attn_varlen_qkvpacked_fp8_func": [torch.float16], + "flash_attn_with_kvcache": [torch.float16], +} + +SUPPORTED_BACKENDS = { + "flash_attn_func": ["ck", "triton"], + "flash_attn_fp8_func": ["triton"], + "flash_attn_kvpacked_func": ["ck", "triton"], + "flash_attn_qkvpacked_func": ["ck", "triton"], + "flash_attn_qkvpacked_fp8_func": ["triton"], + "flash_attn_varlen_func": ["ck", "triton"], + "flash_attn_varlen_fp8_func": ["triton"], + "flash_attn_varlen_kvpacked_func": ["ck", "triton"], + "flash_attn_varlen_qkvpacked_func": ["ck", "triton"], + "flash_attn_varlen_qkvpacked_fp8_func": ["triton"], + "flash_attn_with_kvcache": ["ck", "triton"], } -FUNCTIONS = { - "prefill": attention_prefill, - "decode": attention_decode +VALID_MODES = ['fwd', 'bwd', 'full'] +SUPPORTED_MODES = { + "flash_attn_func": ["fwd", "bwd", "full"], + "flash_attn_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_kvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_qkvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_qkvpacked_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_kvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_qkvpacked_func": ["fwd", "bwd", "full"], + "flash_attn_varlen_qkvpacked_fp8_func": ["fwd", "bwd", "full"], + "flash_attn_with_kvcache": ["fwd"], } -def get_benchmark_configs(args, varlen=False): +@dataclass +class EnvVariableConfig: + key: str + values: List[str] + backend: Optional[Literal["triton", "ck"]] = None + +ENV_VARIABLE_CONFIGS : 
List[EnvVariableConfig] = [ + EnvVariableConfig(key="BWD_MODE", values=["split", "fused", "jingning"], backend="triton"), +] + +class FunctionConfig: + def __init__(self, fn_name: str, mode: Literal["fwd", "bwd", "full"], dtype, backend: Literal["triton", "ck"], env_config: Dict): + self.fn_name = fn_name + self.mode: Literal["fwd", "bwd", "full"] = mode + self.dtype = dtype + self.backend: Literal["triton", "ck"] = backend + self.arch = get_arch() + self.env_configs = env_config + + def __str__(self): + # extract base dtype name if it's a torch dtype + dtype_str = str(self.dtype) + if "torch." in dtype_str: + dtype_str = dtype_str.split(".")[-1] + + if len(self.env_configs) > 0: + env_str = "" + for env_key, env_value in self.env_configs.items(): + env_str += f"{env_key}={env_value}" + return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}_{env_str}" + else: + return f"{self.fn_name}_{self.mode}_{dtype_str}_{self.backend}_{self.arch}" + + def column_name(self): + return f"{self}_ms" + + +@lru_cache() +def available_backends(): + available = [] + + # try to load each backend + for backend in ["triton", "ck"]: + try: + # try loading the module with this backend + flash_attn = load_flash_attn_module(backend) + + # if we got here, the backend loaded successfully + available.append(backend) + except Exception as e: + # backend not available, just continue + print(f"Backend {backend} not available. Error: {e}") + + # if no backends available, default to triton + if not available: + raise ValueError("No Backends available") + + return available + +@lru_cache() +def get_fn_params(fn_name): + # get params for fn + packing = get_packing_type(fn_name) + is_varlen = True if "varlen" in fn_name else False + is_fp8 = True if "fp8" in fn_name else False + supported_dtypes = SUPPORTED_DTYPES.get(fn_name, [torch.float16]) # default to float16 if not found + supported_backends = [backend for backend in SUPPORTED_BACKENDS.get(fn_name, ["triton"]) if backend in available_backends()] # default to triton backend + supports_backward = False if fn_name in ["flash_attn_with_kvcache"] else True + supported_modes = SUPPORTED_MODES.get(fn_name, ["fwd"]) + device = "cuda" + + # get supported env configs for each backend + supported_env_configs = {} + for backend in supported_backends: + supported_env_configs[backend] = get_env_value_combinations(backend) + + # check backward pass support + if not supports_backward: + warning(f"{fn_name} does not have a backward pass so benching forward pass only.") + + return is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes, supported_env_configs, device + +def generate_fn_inputs( + fn_name: str, + BATCH: int, + HQ: int, + HK: int, + N_CTX_Q: int, + N_CTX_K: int, + D_HEAD: int, + CAUSAL: bool, + DROPOUT_P: float, + dtype: torch.dtype, + device: Literal["cpu", "cuda"] + ): + if fn_name == "flash_attn_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_kvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="kv", device=device) + elif fn_name == "flash_attn_qkvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) + elif fn_name == "flash_attn_varlen_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", 
device=device) + elif fn_name == "flash_attn_varlen_kvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="kv", device=device) + elif fn_name == "flash_attn_varlen_qkvpacked_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) + elif fn_name == "flash_attn_with_kvcache": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", device=device) + elif fn_name == "flash_attn_qkvpacked_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="bshd", packing="qkv", device=device) + elif fn_name == "flash_attn_varlen_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", device=device) + elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": + return input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT_P, dtype, layout="thd", packing="qkv", device=device) + else: + valid_fn_names = ", ".join(FUNCTIONS) + raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") + +def estimate_memory(config): + batch, hq, hk, sq, sk, d_head, causal, dropout = config + memory_estimate = batch * (hq * sq + hk * sk) * d_head * 4 # bytes + return memory_estimate + +def generate_benchmark_configs(is_varlen: bool, packing: Optional[Literal["kv", "qkv"]]): """ - Returns benchmark configurations based on whether variable-length sequences are used. + generates a small number of configs that cover the parameter space well """ - if args.custom_config: - hk = args.hq if not args.hk else args.hk - sk = args.sq if not args.sk else args.sk - return [(args.b, args.hq, hk, args.sq, sk)] - elif varlen: - return [ - (2, 16, 4, 1024, 1024), - (8, 16, 2, 2048, 2048), - (4, 16, 8, 4096, 4096), - (2, 16, 4, 8192, 8192), - (2, 16, 8, 16384, 16384), - (2, 48, 12, 1024, 1024), - (2, 48, 24, 2048, 2048), - (2, 48, 8, 4096, 4096), - (2, 48, 4, 8192, 8192), - (2, 48, 2, 16384, 16384), - (2, 64, 32, 1024, 1024), - (4, 64, 16, 2048, 2048), - (4, 64, 8, 4096, 4096), - (4, 64, 32, 8192, 8192), - (4, 128, 16, 16384, 16384), - ] + + # define all parameter options as lists + batch_sizes = [1, 64] + if packing == "qkv": + hq_values = hk_values = [2, 8] + sq_values = sk_values = [256, 8192] else: - return [ - (16, 16, 16, 1024, 1024), - (8, 16, 16, 2048, 2048), - (4, 16, 16, 4096, 4096), - (1, 8, 8, 8192, 8192), - (1, 2, 2, 16384, 16384), - (2, 48, 48, 1024, 1024), - (2, 48, 48, 2048, 1024), - (1, 8, 8, 4096, 8192), - (1, 8, 8, 8192, 4096), - (2, 4, 4, 16384, 8192), - (2, 8, 8, 1989, 15344), - (4, 16, 16, 4097, 163), - (2, 16, 16, 8122, 2159), - (1, 16, 16, 16281, 7), - (2, 48, 48, 1021, 1020), - (2, 48, 48, 2001, 2048), - (2, 8, 8, 3996, 9639), - (2, 8, 8, 8181, 1021), - ] + if is_varlen: # make sure the seqlen is greater than the batchsize so that subsequences are greater than 0 + hq_values = [16, 32] # test mqa/gqa + hk_values = [8, 16] + sq_values = [128, 512] + sk_values = [512, 2024] + else: + hq_values = [64, 128] # test mqa/gqa + hk_values = [16, 64] + sq_values = [4, 4096] + sk_values = [4096, 16384] # test large k values for inference perf + d_head_values = [64, 128] + causal_values = [True, False] # most models usual causal 
True + dropout_values = [0.0, 0.1] + + # generate all fn_configs without inputs + input_configs = [] + + # one big loop to generate configs + for batch in batch_sizes: + for hq in hq_values: + for hk in hk_values: + for sq in sq_values: + for sk in sk_values: + for d_head in d_head_values: + for causal in causal_values: + for dropout in dropout_values: + # filter configs + input_config = (batch, hq, hk, sq, sk, d_head, causal, dropout) + + # skip if memory usage would be too high + if estimate_memory(input_config) > 8 * 1024 * 1024 * 1024: # 8 GB limit + continue + + # we need hq to be a multiple of hk + if hq % hk != 0: + continue + + # for qkvpacked functions, q and k must have same dimensions + if packing == "qkv" and (sq != sk or hq != hk): + continue + + input_configs.append(input_config) + + return input_configs + +def create_benchmark_fn( + flash_attn, + fn_name, + fn_input, + mode: Literal["fwd", "bwd", "full"] +): + if DEBUG: + print("create_benchmark_fn") + print("flash_attn:", flash_attn) + print("fn_name:", fn_name) + print("fn_input:", len(fn_input)) + print("mode:", mode) + + if fn_name == "flash_attn_func": + q, k, v, do, metadata = fn_input + if mode == "fwd": + def flash_attn_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_bench_fn(): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + elif mode == "full": + def flash_attn_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_bench_fn -def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal): - flops_per_matmul = 0 - - if fn_name.startswith("prefill"): - if layout == "thd": - q, k, v, input_metadata = varlen_input_helper( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device) - for i in range(input_metadata.num_contexts): - seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i] - seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i] - flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2 + elif fn_name == "flash_attn_kvpacked_func": + q, kv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_kvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def 
flash_attn_kvpacked_bench_fn(): + dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) + return dq, dkv + elif mode == "full": + def flash_attn_kvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_kvpacked_func( + q, + kv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dkv = torch.autograd.grad(out, (q, kv), do, retain_graph=True) + return dq, dkv else: - q, k, v, input_metadata = input_helper( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_kvpacked_bench_fn + elif fn_name == "flash_attn_qkvpacked_func": + qkv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_qkvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_qkvpacked_bench_fn(): + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + elif mode == "full": + def flash_attn_qkvpacked_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_qkvpacked_bench_fn + elif fn_name == "flash_attn_varlen_func": + q_unpad, k_unpad, v_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_bench_fn(): + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + elif mode == "full": + def flash_attn_varlen_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, 
k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_bench_fn + elif fn_name == "flash_attn_varlen_kvpacked_func": + q_unpad, kv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_kvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, ) - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD - - if causal: - input_metadata.need_causal() - - o = torch.empty_like(q) - input_data = (q, k, v, o, input_metadata) - elif fn_name.startswith("decode"): - q = torch.randn( - [BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ) - k = torch.randn( - [BATCH, N_CTX_K, HK, 1, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ).expand(-1, -1, -1, HQ // HK, -1) - v = torch.randn( - [BATCH, N_CTX_K, HK, 1, D_HEAD], - device=device, - dtype=dtype, - requires_grad=False, - ).expand(-1, -1, -1, HQ // HK, -1) - input_metadata = MetaData(sm_scale=1.3) - input_metadata.layout = "bsghd" - - # Adjust flops calculation if needed - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + def flash_attn_varlen_kvpacked_bench_fn(): + dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) + return dq_unpad, dkv_unpad + elif mode == "full": + def flash_attn_varlen_kvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_kvpacked_func( + q_unpad, + kv_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dkv_unpad = torch.autograd.grad(out_unpad, (q_unpad, kv_unpad), do_unpad, retain_graph=True) + return dq_unpad, dkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_kvpacked_bench_fn + elif fn_name == "flash_attn_varlen_qkvpacked_func": + qkv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_qkvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + 
deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_qkvpacked_bench_fn(): + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + elif mode == "full": + def flash_attn_varlen_qkvpacked_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_qkvpacked_bench_fn + elif fn_name == "flash_attn_with_kvcache": + q, k_cache, v_cache, _, metadata = fn_input + if mode == "fwd": + def flash_attn_with_kvcache_bench_fn(): + out = flash_attn.flash_attn_with_kvcache( + q, + k_cache, + v_cache, + None, + None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens=None, + cache_batch_idx=None, + cache_leftpad=None, + block_table=None, + causal=metadata.causal, + window_size=(-1, -1), + rotary_interleaved=False, + alibi_slopes=None, + num_splits=0, + ) + return out + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_with_kvcache_bench_fn + elif fn_name == "flash_attn_fp8_func": + (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata = fn_input + if mode == "fwd": + def flash_attn_f8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_f8_bench_fn(): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + elif mode == "full": + def flash_attn_f8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_fp8_func( + q, + k, + v, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do, retain_graph=True) + return dq, dk, dv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") - input_data = (q, k, v, input_metadata) + return flash_attn_f8_bench_fn + elif fn_name == "flash_attn_qkvpacked_fp8_func": + qkv, do, metadata = fn_input + if mode == "fwd": + def flash_attn_qkvpacked_fp8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out + elif mode == "bwd": + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_qkvpacked_fp8_bench_fn(): + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + elif mode == "full": + def 
flash_attn_qkvpacked_fp8_bench_fn(): + out, lse, S_dmask = flash_attn.flash_attn_qkvpacked_fp8_func( + qkv, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv = torch.autograd.grad(out, (qkv), do, retain_graph=True) + return dqkv + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_qkvpacked_fp8_bench_fn + elif fn_name == "flash_attn_varlen_fp8_func": + (q_unpad, descale_q), (k_unpad, descale_k), (v_unpad, descale_v), (do_unpad, descale_do), metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_fp8_bench_fn(): + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + elif mode == "full": + def flash_attn_varlen_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_fp8_func( + q_unpad, + k_unpad, + v_unpad, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), do_unpad, retain_graph=True) + return dq_unpad, dk_unpad, dv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_fp8_bench_fn + elif fn_name == "flash_attn_varlen_qkvpacked_fp8_func": + qkv_unpad, do_unpad, metadata = fn_input + if mode == "fwd": + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + return out_unpad + elif mode == "bwd": + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + elif mode == "full": + def flash_attn_varlen_qkvpacked_fp8_bench_fn(): + out_unpad, lse, S_dmask = flash_attn.flash_attn_varlen_qkvpacked_fp8_func( + qkv_unpad, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + 
metadata.dropout_p, + causal=metadata.causal, + window_size=(-1, -1), + softcap=0.0 , + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + ) + dqkv_unpad = torch.autograd.grad(out_unpad, (qkv_unpad), do_unpad, retain_graph=True) + return dqkv_unpad + else: + raise ValueError(f"Unsupported benchmarking mode: {mode}") + + return flash_attn_varlen_qkvpacked_fp8_bench_fn else: - raise ValueError("Unsupported benchmark function") - return input_data, flops_per_matmul + valid_fn_names = ", ".join(FUNCTIONS) + raise ValueError(f"{fn_name} should be one of the following functions. {valid_fn_names}") -def run_benchmark(args, fn_name, fn, mode): +def get_packing_type(fn_name: str) -> Optional[Literal["kv", "qkv"]]: + if "_kvpacked" in fn_name: + packing = "kv" + elif "_qkvpacked" in fn_name: + packing = "qkv" + else: + packing = None + + return packing + +def load_flash_attn_module(backend: Literal["triton", "ck"], env_configs: Dict = {}, verbose = False): """ - Runs the benchmark for the provided function based on the provided arguments. + Load the flash_attn module with the specified backend configuration """ - print(f"Benchmarking {fn_name} in {mode} mode...") - dtype = ARGS_TO_TORCH_DTYPE[args.dtype] - head_size = args.d if args.d else 128 - causal = args.causal - varlen = args.layout == "thd" - return_tflops = args.return_tflops - line_names = "TFLOPS" if return_tflops else "Time (ms)" + # remove any existing env variables first + for key in ENV_FLAGS: + if key in os.environ: + del os.environ[key] - # Determine configurations - x_vals_list = get_benchmark_configs(args, varlen=varlen) + # set environment variable for the desired backend + if backend == "triton": + os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE" + os.environ["FLASH_ATTENTION_TRITON_AMD_AUTOTUNE"] = "0" + os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0" + elif backend == "ck": + os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "FALSE" + else: + raise ValueError(f"Unknown backend {backend}") + + # add custom env configs + add_env_configs(env_configs) + + if verbose: + print(f"Loading flash_attn module with {backend} backend.") + + # Remove any existing flash_attn modules from sys.modules + for module_name in list(sys.modules.keys()): + if module_name.startswith('flash_attn'): + del sys.modules[module_name] + + # Clear CUDA cache + torch.cuda.empty_cache() + + # Import and return the module + import flash_attn + + return flash_attn + +def add_env_configs(env_config: Dict): + for env_key, env_value in env_config.items(): + if env_key in os.environ: + del os.environ[env_key] # remove previous version so that env key is the latest key added + os.environ[env_key] = env_value + +def run_benchmark(func_config: FunctionConfig, input_configs): + """ + Runs the benchmark for the provided function configuration with the given input configurations. 
+ """ + # print new line to seperate benchmark runs + print() + if DEBUG: + print("func_config:", func_config) + + # extract function configuration parameters + fn_name = func_config.fn_name + mode = func_config.mode + dtype = func_config.dtype + backend = func_config.backend + + # load flash attention module + flash_attn_module = load_flash_attn_module(backend, func_config.env_configs, verbose=True) + + # start timing the benchmark + start_time = time.time() + + # print bench fn + print(f"Benchmarking {func_config} ...") # Setup benchmark configurations - configs = [ + bench_configs = [ triton.testing.Benchmark( - x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"], - x_vals=x_vals_list, + x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"], + x_vals=list(input_configs.keys()), line_arg="provider", line_vals=["triton"], - line_names=[line_names], + line_names=["Time (ms)"], styles=[("red", "-")], ylabel="ms", - plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}", + plot_name=f"benchmark-{func_config}", args={ - "D_HEAD": head_size, - "dtype": dtype, - "causal": causal, - "mode": mode, }, ) ] - @triton.testing.perf_report(configs) + @triton.testing.perf_report(bench_configs) def bench_function( - BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda" + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT, provider, device="cuda" ): - warmup = 25 - rep = 100 - flops_per_matmul = 0 + if DEBUG: + print("BATCH:", BATCH) + print("HQ:", HQ) + print("HK:", HK) + print("N_CTX_Q:", N_CTX_Q) + print("N_CTX_Q:", N_CTX_Q) + print("D_HEAD:", D_HEAD) + print("CAUSAL:", CAUSAL) + print("DROPOUT:", DROPOUT) + print("mode:", mode) + print("provider:", provider) + print("device:", device) + fn_input = input_configs[(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, CAUSAL, DROPOUT)] + benchmark_fn = create_benchmark_fn(flash_attn_module, fn_name, fn_input, mode) - # generate function inputs - fn_inputs, flops_per_matmul = gen_fn_inputs( - fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal - ) + # run the benchmark + ms = triton.testing.do_bench(benchmark_fn, warmup=25, rep=100) + return ms - # define the function to benchmark - if mode == "fwd": - benchmark_fn = lambda: fn(*fn_inputs) - total_flops = 2 * flops_per_matmul - elif mode == "bwd": - outputs = fn(*fn_inputs) - output = outputs[0] - grad_output = torch.randn_like(output) - benchmark_fn = lambda: output.backward(grad_output, retain_graph=True) - total_flops = 2 * flops_per_matmul * 2.5 - else: - raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.") + df = bench_function.run(save_path=".", print_data=True, return_df=True)[0] + + # set the column name to reflect the function configuration + df = df.rename(columns={"Time (ms)": func_config.column_name()}) + + # calculate and print elapsed time + elapsed_time = time.time() - start_time + print(f"Total time for benchmarking {fn_name} in {mode} mode with {dtype}: {elapsed_time:.2f} seconds") - if causal: - total_flops *= 0.5 + return df - # Run the benchmark - ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep) +def filter_modes(requested_modes, fn_name, supported_modes_for_fn): + modes_to_run = [] + if requested_modes: + for mode in requested_modes: + if mode in supported_modes_for_fn: + modes_to_run.append(mode) + else: + warning(f"Mode '{mode}' requested but not supported by function '{fn_name}'. 
Skipping this mode for this function.") + else: + modes_to_run = ["full" if "full" in supported_modes_for_fn else "fwd"] + return modes_to_run - if return_tflops: - return total_flops / ms * 1e-9 - else: - return ms +def get_env_value_combinations(current_backend: Optional[Literal["triton", "ck"]]) -> List[Dict[str, str]]: + # filter environment variations applicable to the current backend + applicable_variations = [ + var_config for var_config in ENV_VARIABLE_CONFIGS + if var_config.backend is None or var_config.backend == current_backend + ] - bench_function.run(save_path=".", print_data=True) + if not applicable_variations: + # no applicable variations, return list with empty dict + return [{}] -def supported_layouts(): - """ - Returns a string describing the supported layouts. - """ - return ( - "bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n" - "bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n" - "thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n" - 'This layout is sometimes called "varlen" or "grouped" layout.' - ) + # prepare keys and value lists + variation_keys = [v.key for v in applicable_variations] + variation_value_lists = [v.values for v in applicable_variations] + + # generate all combinations as dictionaries directly + env_configs = [] + for value_combination in itertools.product(*variation_value_lists): + env_configs.append(dict(zip(variation_keys, value_combination))) + + return env_configs + +def get_input_config_set(config_type): + if config_type == "llama": + # batch, hq, hk, sq, sk, d_head, causal, dropout + input_configs = [ + # LLaMA 3 8B + (1, 32, 8, 8192, 8192, 128, True, 0.0), + # LLaMA 3 70B + (1, 64, 8, 8192, 8192, 128, True, 0.0), + ] + else: + raise ValueError(f"Unknown input config: {config_type}") + + return input_configs -def parse_args(): + +def process_args(): """ - Parses command-line arguments. + Parses command-line arguments and returns function configs and input configs. """ + # create parser parser = argparse.ArgumentParser( prog="Benchmark FlashAttention", allow_abbrev=False, ) - parser.add_argument("-b", type=int, default=0) - parser.add_argument("-hq", type=int, default=0) - parser.add_argument("-hk", type=int, default=0) - parser.add_argument("-sq", type=int, default=0) - parser.add_argument("-sk", type=int, default=0) - parser.add_argument( - "-equal_seqlens", - action="store_true", - default=False, - help="If specified, each context within the thd layout has same seqlen as sq and sk", - ) - parser.add_argument("-d", type=int, default=0) - parser.add_argument("-causal", action="store_true", default=False) - parser.add_argument("-dtype", default="fp16") - parser.add_argument("-return_tflops", action="store_true", default=False) - parser.add_argument( - "-layout", - type=str, - default="bhsd", - help=supported_layouts(), - ) + # functions parser.add_argument( "-benchmark_fn", type=str, nargs="*", - choices=FUNCTIONS.keys(), - help="Function(s) to benchmark: prefill, decode, or both", + choices=FUNCTIONS, + required=True, + help=f"Function(s) to benchmark", ) parser.add_argument( - "-mode", + "--mode", type=str, nargs='*', - default=["fwd", "bwd"], - choices=["fwd", "bwd"], - help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass", + choices=VALID_MODES, + default=None, + help=f"Benchmarking mode(s) to run. 
If omitted, runs all supported modes for each function.", + ) - return parser.parse_args() + # config + parser.add_argument("-b", type=int, default=None, help="Batch size") + parser.add_argument("-hq", type=int, default=None, help="Q Number of heads") + parser.add_argument("-hk", type=int, default=None, help="K and V Number of heads") + parser.add_argument("-sq", type=int, default=None, help="Q Sequence Length") + parser.add_argument("-sk", type=int, default=None, help="K and V Sequence Length") + parser.add_argument("-d", type=int, default=None, help="Head Dimension") + parser.add_argument("-causal", action="store_true", default=None, help="Causal") + parser.add_argument("-dropout", type=float, default=None, help="Dropout") + + # parse args + args = parser.parse_args() + + # parse function args + benchmark_fns = args.benchmark_fn + requested_modes = args.mode + + # generate function configurations and input configurations separately + all_function_configs = [] + all_input_configs = {} # Maps function config -> input configs + for fn_name in benchmark_fns: + is_varlen, is_fp8, packing, supported_dtypes, supported_backends, supported_modes_for_fn, supported_env_configs, device = get_fn_params(fn_name) + + # Generate or use custom input configurations + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + assert args.b and args.hq and args.sq and args.d, ( + "If custom config is specified, please provide at least batch, number of Q heads, Q sequence length, and head size." + ) + + batch = args.b + hq = args.hq + hk = args.hk if args.hk is not None else args.hq + sq = args.sq + sk = args.sk if args.sk is not None else args.sq + d_head = args.d + causal = args.causal if args.causal is not None else False + dropout = args.dropout if args.dropout is not None else 0.0 + input_configs = [(batch, hq, hk, sq, sk, d_head, causal, dropout)] + else: + if True: + input_configs = get_input_config_set("llama") + else: + input_configs = generate_benchmark_configs(is_varlen, packing) + + # filter by mode + modes_to_run = filter_modes(requested_modes, fn_name, supported_modes_for_fn) + if not modes_to_run: + warning(f"No valid modes to run for function '{fn_name}' based on request and function support. Skipping this function.") + continue + + # create a function config for each backend and dtype combination + for backend in supported_backends: + for dtype in supported_dtypes: + for mode in modes_to_run: + for env_config in supported_env_configs[backend]: + func_config = FunctionConfig(fn_name, mode, dtype, backend, env_config) + all_function_configs.append(func_config) + + # Generate inputs for this function configuration + fn_inputs = {} + for input_config in input_configs: + fn_inputs[input_config] = generate_fn_inputs(fn_name, *input_config, dtype, device) + + all_input_configs[func_config] = fn_inputs + + return all_function_configs, all_input_configs + +def check_environment_variables(): + for key in ENV_FLAGS: + if key in os.environ: + raise ValueError(f"Running with {key} environment variable is not recommended for the benchmarking script. Use --help to see how to use the benchmarking script.") def main(): """ Main function to run benchmarks. """ - args = parse_args() - - # Validate arguments - assert ( - args.layout == "thd" or not args.equal_seqlens - ), "Equal sequence lengths arg must be used with the thd layout." 
- args.custom_config = False - if args.b or args.hq or args.hk or args.sq or args.sk or args.d: - args.custom_config = True - assert args.b and args.hq and args.sq and args.d, ( - "If custom config is specified, please provide all of batch, " - "number of Q heads, Q sequence length, and head size." - ) - assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported." + # check environment variables + check_environment_variables() - # determine the functions to benchmark - if args.benchmark_fn is None or len(args.benchmark_fn) == 0: - bench_fn_list = FUNCTIONS.keys() - else: - bench_fn_list = args.benchmark_fn - - # benchmark functions - for fn_name in bench_fn_list: - if fn_name not in FUNCTIONS: - raise ValueError(f"Invalid benchmark function specified: {fn_name}") - for mode in args.mode: - if fn_name == "decode" and mode == "bwd": - print(f"Decode kernel doesnot have a backward pass") - continue - run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode) + # start timing the entire benchmarking process + total_start_time = time.time() + + # process args to get function configs and input configs + function_configs, all_input_configs = process_args() + + # Check if we have multiple function configurations + has_multiple_func_configs = len(function_configs) > 1 + combined_df = None + + # run benchmarks for each function configuration + for func_config in function_configs: + # run benchmark with the input configs for this function config + input_configs = all_input_configs[func_config] + df = run_benchmark(func_config, input_configs) + + # Define the columns that represent input configurations + input_config_cols = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K", "D_HEAD", "CAUSAL", "DROPOUT"] + + # merge into one final dataframe + if combined_df is None: + combined_df = df + else: + # Ensure we're joining on input configuration columns + combined_df = combined_df.merge(df, on=input_config_cols, how="outer") + + + # print new line to separate the combined data information from the benchmark specific information + print() + + # print total time for all benchmarks + total_elapsed_time = time.time() - total_start_time + print(f"Total time for all benchmarks: {total_elapsed_time:.2f} seconds") + + # save combined data and make comparisons if we have multiple function configs + if has_multiple_func_configs: + if len(function_configs) == 2: + func1 = function_configs[0] + func2 = function_configs[1] + + # construct column names for the timing results + col1 = func1.column_name() + col2 = func2.column_name() + + # Check if we're comparing triton vs ck (in either order) + is_triton_vs_ck = ( + (func1.backend == "triton" and func2.backend == "ck") or + (func1.backend == "ck" and func2.backend == "triton") + ) + + # For triton vs ck comparisons + if is_triton_vs_ck: + # For triton vs ck comparisons, always make triton the baseline + if func1.backend == "triton" and func2.backend == "ck": + triton_col = col1 + ck_col = col2 + ratio_col = f"ck_to_triton_ratio" + else: + triton_col = col2 + ck_col = col1 + ratio_col = f"ck_to_triton_ratio" + + # Calculate ratio: ck_time / triton_time (values > 1 mean triton is faster) + combined_df[ratio_col] = combined_df[ck_col] / combined_df[triton_col] + + # print explanation + print(f"Comparison Results (triton vs ck):") + print(f"Ratio values: values > 1 mean triton is faster (by that factor), values < 1 mean ck is faster") + elif False: + # For other comparisons, use the standard approach + ratio_col = f"{func1}_to_{func2}_ratio" + + # 
Calculate the ratio + combined_df[ratio_col] = combined_df[col2] / combined_df[col1] + + # print explanation + print(f"Comparison Results ({func1} vs {func2}):") + print(f"Ratio values: values > 1 mean {func1} is faster than {func2} (by that factor), values < 1 mean slower") + + print(f"Combined data:") + print(combined_df) + + # save csv & markdown + combined_filename = f"benchmark_combined" + combined_df.to_csv(f"{combined_filename}.csv", index=False) + with open(f"{combined_filename}.md", 'w') as f: + f.write(combined_df.to_markdown(index=False, floatfmt=".2f")) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py index 84212235a6..44e2c294b0 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -1,10 +1,16 @@ +from typing import Literal, Optional import torch import triton import triton.language as tl -from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_fp8, write_dropout_mask, create_dropout_mask + +# TODO: move this into utils.py so it's shared among kernels +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @triton.jit -def _bwd_preprocess_use_o( +def _bwd_preprocess( Out, DO, Delta, @@ -15,16 +21,18 @@ def _bwd_preprocess_use_o( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + DESCALE_do, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, N_CTX_Q: tl.constexpr, Z: tl.constexpr, H: tl.constexpr, - IS_VARLEN: tl.constexpr + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, ): - pid_m = tl.program_id(0) - pid_bh = tl.program_id(1) + pid_bh = tl.program_id(0) + pid_m = tl.program_id(1) # Compute batch and head indices off_z = pid_bh // H @@ -62,11 +70,18 @@ def _bwd_preprocess_use_o( do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok # load - o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) - do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32) + o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) + do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0) # compute delta - delta = tl.sum(o * do, axis=1) + if IS_FP8: + stride_descale_q_z = H + descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_h) + + # NOTE: do is scaled into the fp8 range and o is in fp8 but should be in the same scale as fp32 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) # write-back delta delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam @@ -94,8 +109,9 @@ def _bwd_kernel_one_col_block( dq_offset, dk_offset, dv_offset, - d_offset, l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -112,23 +128,30 @@ def _bwd_kernel_one_col_block( stride_deltaz, stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, 
num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, USE_EXP2: tl.constexpr, + GROUP_SIZE: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, ): if CAUSAL: # TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M @@ -154,11 +177,12 @@ def _bwd_kernel_one_col_block( k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk k = tl.load(k_ptrs, mask=kv_mask, other=0.0) - v = tl.load(v_ptrs, mask=kv_mask, other=0.0) + kT = tl.trans(k) + vT = tl.trans(tl.load(v_ptrs, mask=kv_mask, other=0.0)) # loop over rows - for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M): - offs_m = start_m + tl.arange(0, BLOCK_M) + for start_m in range(lo, num_block_m): + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk @@ -173,7 +197,10 @@ def _bwd_kernel_one_col_block( # recompute p = softmax(qk, dim=-1).T qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, tl.trans(k)) + if IS_FP8: + qk += (tl.dot(q, kT) * descale_q * descale_k) + else: + qk += tl.dot(q, kT) if CAUSAL: col_offset = N_CTX_Q - N_CTX_K @@ -197,27 +224,89 @@ def _bwd_kernel_one_col_block( p_mask = mask_m[:, None] & mask_n[None, :] p = tl.where(p_mask, p, 0.0) - # compute dv - dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + if DROPOUT: + # NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing + philox_offset = batch_philox_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + # print("philox_seed:", philox_seed) + # print("philox_offset:", philox_offset) + if tl_DROPOUT_USE_PYTORCH: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load(dropout_ptrs, mask=p_mask) + else: + rand_vals = tl.rand(philox_seed, philox_offset) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1/ (1 - dropout_p) + + if tl_DROPOUT_DUMP: + dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn + tl.store(dropout_ptrs, dropout_mask, mask=p_mask) + + # apply dropout mask + p_drop = tl.where(dropout_mask, p, 0.0) + p_drop_scaled = p_drop * dropout_scale + + # compute dv + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(p_drop_scaled, FP8_MAX) + dv += (tl.dot(tl.trans(p_drop_scaled * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(tl.trans(p_drop_scaled).to(do.type.element_ty), do) + + # compute dp + if IS_FP8: + dp_drop_scaled = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp_drop_scaled = tl.dot(do, vT) + dp = tl.where(dropout_mask, dp_drop_scaled, 0.0) * dropout_scale + else: + + # compute dv + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + dv += (tl.dot(tl.trans(p * scale_p).to(do.type.element_ty), do) * descale_p * descale_do) + else: + dv += tl.dot(tl.trans(p).to(do.type.element_ty), do) - # compute dp - dp = 
tl.dot(do, tl.trans(v)) + # compute dp + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) - # compute ds , ds = p * (dp - delta[:, None]) - d_ptrs = d_offset + offs_m * stride_deltam - Di = tl.load(d_ptrs, mask=mask_m) - ds = (p * (dp - Di[:, None])) * sm_scale - ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty) - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds), q) + # load delta + delta_ptrs = delta_offset + offs_m * stride_deltam + delta_i = tl.load(delta_ptrs, mask=mask_m) + + # compute ds + dscores_scaled = (p * (dp - delta_i[:, None])) + ds = dscores_scaled * sm_scale + ds = tl.where(p_mask, ds, 0.0) + + # compute descale_ds + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + else: + scale_ds, descale_ds = 1.0, 1.0 + + # compute dk + if IS_FP8: + dk += (tl.dot(tl.trans(ds * scale_ds).to(q.type.element_ty), q) * descale_ds * descale_q) + else: + dk += tl.dot(tl.trans(ds).to(q.type.element_ty), q) # compute dq if SEQUENCE_PARALLEL: - dq = tl.dot(ds, k) + if IS_FP8: + dq = (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) + else: + dq = tl.dot(ds.to(k.type.element_ty), k) else: dq = tl.load(dq_ptrs, mask=q_mask, other=0.0) - dq += tl.dot(ds, k) + if IS_FP8: + dq += (tl.dot((ds * scale_ds).to(k.type.element_ty), k) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(k.type.element_ty), k) tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask) # write-back dv and dk @@ -225,8 +314,13 @@ def _bwd_kernel_one_col_block( dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk # write-back - tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) - tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + if GROUP_SIZE != 1: + # use atomic_add to properly accumulate gradients from multiple query heads + tl.atomic_add(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.atomic_add(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) + else: + tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask) + tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask) @triton.jit def _bwd_kernel( @@ -240,7 +334,12 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, + Dropout_mask, + DESCALE_q, + DESCALE_k, + DESCALE_v, + DESCALE_do, stride_dq_all, stride_qz, stride_qh, @@ -257,29 +356,44 @@ def _bwd_kernel( stride_deltaz, stride_deltah, stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, Z, - H, + HQ, + HK, num_block_m, num_block_n, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset_base, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + DROPOUT: tl.constexpr, USE_EXP2: tl.constexpr, IS_VARLEN: tl.constexpr, + GROUP_SIZE: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, ): # program ids - off_hz = tl.program_id(0) + off_zh = tl.program_id(0) if SEQUENCE_PARALLEL: start_n = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H + off_z = off_zh // HQ + off_hq = off_zh % HQ + + # check if GQA/MQA + if GROUP_SIZE != 1: + off_hk = off_hq // GROUP_SIZE + else: + off_hk = off_hq if IS_VARLEN: # Compute sequence lengths for the current batch @@ -296,23 +410,40 @@ def _bwd_kernel( k_start = 0 N_CTX_Q = max_seqlen_q N_CTX_K = max_seqlen_k - # input tensor offsets - q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm - k_offset = K + off_z * 
stride_kz + off_h * stride_kh + k_start * stride_kn - v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn - do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm - l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam - d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam + q_offset = Q + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + k_offset = K + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn + v_offset = V + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn + do_offset = DO + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm + l_offset = L + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + delta_offset = Delta + off_z * stride_deltaz + off_hq * stride_deltah + q_start * stride_deltam + + if DROPOUT: + batch_philox_offset = philox_offset_base + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + dropout_offset = Dropout_mask + off_z * stride_dropoutz + off_hq * stride_dropouth #+ q_start * stride_dropoutm + else: + batch_philox_offset = 0 + dropout_offset = 0 + + if IS_FP8: + stride_descale_q_z = HQ + stride_descale_kv_z = HK + descale_q = tl.load(DESCALE_q + off_z * stride_descale_q_z + off_hq) + descale_k = tl.load(DESCALE_k + off_z * stride_descale_kv_z + off_hk) + descale_v = tl.load(DESCALE_v + off_z * stride_descale_kv_z + off_hk) + descale_do = tl.load(DESCALE_do + off_z * stride_descale_q_z + off_hq) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + # output tensor offsets - dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn - dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn + dk_offset = DK + off_z * stride_kz + off_hk * stride_kh + k_start * stride_kn + dv_offset = DV + off_z * stride_vz + off_hk * stride_vh + k_start * stride_vn if SEQUENCE_PARALLEL: - dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm else: - dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + dq_offset = DQ + off_z * stride_qz + off_hq * stride_qh + q_start * stride_qm # inner loop if SEQUENCE_PARALLEL: @@ -327,7 +458,7 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, q_offset, k_offset, v_offset, @@ -335,8 +466,9 @@ def _bwd_kernel( dq_offset, dk_offset, dv_offset, - d_offset, l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -350,26 +482,33 @@ def _bwd_kernel( stride_vh, stride_vn, stride_vk, - stride_deltaz, - stride_deltah, + stride_deltaz, + stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + DROPOUT=DROPOUT, USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) else: for start_n in range(0, num_block_n): @@ -384,7 +523,7 @@ def _bwd_kernel( DK, DV, L, - D, + Delta, q_offset, k_offset, v_offset, @@ -392,8 +531,9 @@ def _bwd_kernel( dq_offset, dk_offset, dv_offset, - d_offset, 
l_offset, + delta_offset, + dropout_offset, stride_dq_all, stride_qz, stride_qh, @@ -407,54 +547,69 @@ def _bwd_kernel( stride_vh, stride_vn, stride_vk, - stride_deltaz, - stride_deltah, + stride_deltaz, + stride_deltah, stride_deltam, - Z, - H, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, N_CTX_Q, N_CTX_K, - off_h, - off_z, - off_hz, start_n, num_block_m, num_block_n, + dropout_p, + philox_seed, + batch_philox_offset, + descale_q, + descale_k, + descale_v, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + DROPOUT=DROPOUT, USE_EXP2=USE_EXP2, + GROUP_SIZE=GROUP_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) -# NOTE: smaller blocks have lower accuracy. more accumlation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumlation errors but no oom. +# NOTE: smaller blocks have lower accuracy. more accumulation error probably 128 * 128 seems good but leads to oom. 64 * 64 has accumulation errors but no oom. def attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, sm_scale: float, - alibi_slopes, - causal, - layout: str, - cu_seqlens_q, - cu_seqlens_k, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], max_seqlen_q: int, max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], use_exp2: bool, - sequence_parallel = True, + sequence_parallel: bool = True, + # fp8 + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, ): if DEBUG: print() - print("attention_prefill_backward_triton_new_impl") + print("attention_prefill_backward_triton_impl") print("do:", do, do.shape) print("q:", q, q.shape) print("k:", k, k.shape) @@ -472,24 +627,38 @@ def attention_prefill_backward_triton_impl( print("cu_seqlens_k:", cu_seqlens_k) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) print("sequence_parallel:", sequence_parallel) + print("descale_q:", descale_q) + print("descale_k:", descale_k) + print("descale_v:", descale_v) + print("descale_do:", descale_do) + + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX=torch.finfo(q.dtype).max + else: + FP8_MAX=None - # make contigious + # make contiguous q = q.contiguous() k = k.contiguous() v = v.contiguous() softmax_lse = softmax_lse.contiguous() # get strides and shape - batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) stride_qz, stride_qh, stride_qm, stride_qk = q_strides stride_kz, stride_kh, stride_kn, stride_kk = k_strides stride_vz, stride_vh, stride_vn, stride_vk = v_strides stride_oz, 
stride_oh, stride_om, stride_ok = o_strides - batch_headsize = batch * nheads_q is_varlen = layout == "thd" + group_size = nheads_q // nheads_k + use_dropout = (dropout_p > 0.0) # FIXME: some configs lead to oom for some reason when using 64 x 64 blocks if max_seqlen_q <= 32 or max_seqlen_k <= 32: @@ -498,7 +667,12 @@ def attention_prefill_backward_triton_impl( else: BLOCK_M = 64 BLOCK_N = 64 - num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful + + if DEBUG: + print("BLOCK_M:", BLOCK_M) + print("BLOCK_N:", BLOCK_N) + + num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues be careful num_stages = 1 waves_per_eu = 1 @@ -513,48 +687,13 @@ def attention_prefill_backward_triton_impl( ACTUAL_BLOCK_DMODEL = head_size do = do.contiguous() - # NOTE: we might need to copy the output tensor if they are not continuous or have other issues - copy_back = {"dq": False, "dk": False, "dv": False} # deal with dq - if dq is None: - if sequence_parallel: - dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) - else: - dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype) - else: - dq_og = dq - if (not dq.is_contiguous()): - dq = dq.contiguous() - copy_back["dq"] = True - - if sequence_parallel: - dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) - copy_back["dq"] = True - else: - # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros - dq.zero_() + if sequence_parallel: + dq = dq.unsqueeze(0).repeat(num_blocks_n, *([1] * len(q.shape))) # we do repeat instead of expand because we need to write data so views are not enough stride_dq_all = dq.stride()[0] - # deal with dk, dv - if (dk is None) or (dv is None): - dk = torch.empty_like(k) - dv = torch.empty_like(v) - else: - if (not dk.is_contiguous()): - dk_og = dk - dk = dk.contiguous() - copy_back["dk"] = True - - if (not dv.is_contiguous()): - dv_og = dv - dv = dv.contiguous() - copy_back["dv"] = True - - if DEBUG: - print("copy_back:", copy_back) - - # assert contigious + # assert contiguous assert do.is_contiguous() assert q.is_contiguous() assert k.is_contiguous() @@ -563,66 +702,53 @@ def attention_prefill_backward_triton_impl( assert softmax_lse.is_contiguous() # init delta - delta = torch.empty_like(softmax_lse) + delta = torch.zeros_like(softmax_lse) if is_varlen: stride_deltam, stride_deltah = delta.stride() stride_deltaz = 0 else: stride_deltaz, stride_deltah, stride_deltam = delta.stride() - _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)]( + # dropout mask tensor for debugging. 
We dump the dropout mask created in the kernel for testing + if use_dropout: + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlen_q, max_seqlen_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlen_q, max_seqlen_k), device=q.device, + dtype=torch.float32) + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (dropout_mask.stride(0), dropout_mask.stride(1), dropout_mask.stride(2), dropout_mask.stride(3)) + else: + dropout_mask = None + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0 , 0 , 0) + + + _bwd_preprocess[(batch * nheads_q, num_blocks_m)]( o, do, delta, stride_oz, stride_oh, stride_om, stride_ok, - stride_oz, stride_oh, stride_om, stride_ok, + stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues stride_deltaz, stride_deltah, stride_deltam, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + descale_do, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, N_CTX_Q=max_seqlen_q, Z=batch, H=nheads_q, - IS_VARLEN=is_varlen + IS_VARLEN=is_varlen, + IS_FP8=IS_FP8 ) if DEBUG: - print("_bwd_kernel inputs") - print("do:", do, do.shape) - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale", sm_scale) - print("o:", o, o.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("L:", softmax_lse, softmax_lse.shape) print("delta:", delta, delta.shape) - print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk) - print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk) - print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk) - print("batch_q:", batch) - print("heads_q:",nheads_q) - print("max_seqlen_q:",max_seqlen_q) - print("max_seqlen_k:",max_seqlen_k) - print("BLOCK_M:",BLOCK_M) - print("BLOCK_N:",BLOCK_M) - print("BLOCK_DMODEL:",BLOCK_DMODEL) - print("ACTUAL_BLOCK_DMODEL:",ACTUAL_BLOCK_DMODEL) - print("SEQUENCE_PARALLEL:",sequence_parallel) - print("CAUSAL:",causal) - print("num_warps:",num_warps) - print("num_stages:", num_stages) - print("USE_EXP2:", use_exp2) - print("num_blocks_m:", num_blocks_m) - print("num_blocks_n:", num_blocks_n) - - _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)]( + print("group_size:", group_size) + + _bwd_kernel[(batch * nheads_q, num_blocks_n if sequence_parallel else 1)]( q, k, v, @@ -634,58 +760,55 @@ def attention_prefill_backward_triton_impl( dv, softmax_lse, delta, + dropout_mask, + descale_q, + descale_k, + descale_v, + descale_do, stride_dq_all, - stride_qz, stride_qh, stride_qm, stride_qk, + stride_qz, stride_qh, stride_qm, stride_qk, # FIXME: don't share strides with derivatives this was causing a lot of issues stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_deltaz, stride_deltah, stride_deltam, + stride_dropoutz, stride_dropouth, stride_dropoutm, stride_dropoutn, batch, nheads_q, + nheads_k, num_blocks_m, num_blocks_n, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, philox_seed, philox_offset, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL, ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL, SEQUENCE_PARALLEL=sequence_parallel, CAUSAL=causal, + DROPOUT=use_dropout, USE_EXP2=use_exp2, num_warps=num_warps, 
num_stages=num_stages, waves_per_eu = waves_per_eu, - IS_VARLEN=is_varlen + IS_VARLEN=is_varlen, + GROUP_SIZE=group_size, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX ) - if DEBUG: - print("_bwd_kernel outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape) - if sequence_parallel: dq = dq.sum(dim=0) if DEBUG: - print("attention_prefill_backward_triton_new_impl outputs") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) + print("attention_prefill_backward_triton_impl outputs") print("dv:", dv, dv.shape) - print("delta:", delta, delta.shape) - print("copy_back:", copy_back) - - if copy_back["dq"]: - dq_og.copy_(dq) - dq = dq_og - if copy_back["dk"]: - dk_og.copy_(dk) - dk = dk_og - if copy_back["dv"]: - dv_og.copy_(dv) - dv = dv_og - - return dq, dk, dv, delta, None, None + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) + if use_dropout: + print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None) + print("dropout_fraction bwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item()) + write_dropout_mask(dropout_mask, "dropout_mask_bwd") + + return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py new file mode 100644 index 0000000000..3c018be4fa --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused.py @@ -0,0 +1,3266 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Tuple + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + # compute fp8 scaling and descaling factor for a block + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + +def is_fp8(x): + if x.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: + if arch_supports_fp8(): + return True + else: + raise RuntimeError("This device does not support fp8") + else: + return False + + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype, + layout, + clamp_val=1e-9, +): + if len(x.shape) != 4: + raise ValueError(f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}") + reduce_dims = (1, 3) # seq_len and dim dimensions + + # Compute the absolute max along reduce_dims, clamped to avoid 0-scale + x_abs_max = x.abs().amax(dim=reduce_dims) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Unsqueeze back to a shape suitable for broadcast + unsqueeze_dims = sorted(reduce_dims) + for d in unsqueeze_dims: + x_abs_max = x_abs_max.unsqueeze(d) + + # compute scale and descale + fp8_max = torch.finfo(fp8_dtype).max + scale = fp8_max / x_abs_max + descale_factor = x_abs_max / fp8_max + + # cast to FP8, optionally setting requires_grad + x_fp8 = (x * scale).to(fp8_dtype) + + return x_fp8, descale_factor + + +def cast_varlen_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + cu_seqlens, + clamp_val: float = 1e-9, +) -> tuple[torch.Tensor, torch.Tensor]: + # validate tensor shape + if len(x.shape) != 3: + raise ValueError(f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}") + num_heads = x.shape[1] + + # Get batch size from cu_seqlens + batch = cu_seqlens.shape[0] - 1 + fp8_max = torch.finfo(fp8_dtype).max + + # Compute scale and descale factors per sequence + 
x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) + + for i in range(batch): + start = cu_seqlens[i] + end = cu_seqlens[i + 1] + x_slice = x[start:end] # Slice for current sequence + + # Standard tensor (0: seq_len, 2: head_dim) + x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] + + # apply minimum clamping + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # compute scale and descale factors + scale_i = fp8_max / x_abs_max + descale_i = x_abs_max / fp8_max + + # store descale factors + descale_factors[i, :] = descale_i + + scale_reshape = scale_i.reshape(1, num_heads, 1) + + # scale and cast to FP8 + x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) + + return x_fp8, descale_factors + + +#TODO Move this to a common folder. Will need to add future arch list +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + +def arch_supports_fp8(): + return is_hip() and get_arch() in ('gfx942') + +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vk, + stride_sn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + sd_mask_ptrs, + dropout_mask_ptrs, + philox_seed, + philox_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + descale_q, + descale_k, + descale_v, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + SM_SCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + PADDED_HEAD: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # compute masks + q_mask = (OFFS_M[:, None] < seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < seqlen_k) + p_mask = q_mask & k_mask + + # -- compute qk ---- + if IS_FP8: + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, global_m_positions, + global_n_positions) + qk_scaled += alibi_block + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = qk_scaled - m_ij[:, None] + + # Compute scaled QK and softmax probabilities + p = tl.math.exp2(q_shifted * RCP_LN2) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + tl.store(sd_mask_ptrs, p, mask=p_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = m_i - m_ij + alpha = tl.math.exp2(m_diff * RCP_LN2) + acc = acc * alpha[:, None] + v = load_fn(v_ptrs, k_offs_n, k_offs_k, seqlen_k, BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if RETURN_SCORES: + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn + + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd(q_ptr: torch.Tensor, + k_ptr: torch.Tensor, + v_ptr: torch.Tensor, + descale_q_ptr: torch.Tensor, + descale_k_ptr: torch.Tensor, + descale_v_ptr: torch.Tensor, + out_ptr: torch.Tensor, + alibi_slopes_ptr: torch.Tensor, + s_dmask_ptr: torch.Tensor, + dropout_mask_ptr: torch.Tensor, + softmax_lse_ptr: torch.Tensor, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, + stride_oz, stride_oh, stride_om, stride_on, + stride_alibi_z, stride_alibi_h, + stride_sd_z, stride_sd_h, stride_sd_m, stride_sd_n, + stride_lse_z, stride_lse_h, stride_lse_m, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset, + SEQLEN_Q: tl.constexpr, + SEQLEN_K: tl.constexpr, + IS_CAUSAL: tl.constexpr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_POW2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + VARLEN: tl.constexpr, +): + #calculate offsets + start_m = tl.program_id(0) #seqlen_q + off_q_head = tl.program_id(1) #num_q_heads + off_z = tl.program_id(2) #batch + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_POW2) + + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = SEQLEN_Q + seqlen_k = SEQLEN_K + + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. 
+ # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + offs_out = (off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=out_ptr.type.element_ty) + out_mask = (offs_m[:, None] < seqlen_q) & (offs_d < BLOCK_DMODEL) + tl.store(out_ptr + offs_out, acc, mask=out_mask) + + if softmax_lse_ptr is not None: + offs_lse = (off_z * stride_lse_z + + off_q_head * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m*stride_lse_m + ) + lse_mask = offs_m < SEQLEN_Q + lse = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) + tl.store(softmax_lse_ptr + offs_lse, lse, mask=lse_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + + return + + grp_sz:tl.constexpr = NUM_Q_HEADS // NUM_K_HEADS + if grp_sz != 1: #Grouped Query Attention + off_k_head = off_q_head // grp_sz + else: + off_k_head = off_q_head + + #q,k,v offsets + q_offs = (off_z * stride_qz + + off_q_head * stride_qh + + cu_seqlens_q_start * stride_qm + + offs_m[:, None] * stride_qm + offs_d[None, :]*stride_qk + ) + q_ptrs = q_ptr + q_offs + + k_offs = (off_z * stride_kz + + off_k_head * stride_kh + + cu_seqlens_k_start * stride_kn + + offs_d[:, None] * stride_kk + offs_n[None, :]*stride_kn + ) + k_ptrs = k_ptr + k_offs + + v_offs = (off_z * stride_vz + + off_k_head * stride_vh + + cu_seqlens_k_start * stride_vn + + offs_n[:, None] * stride_vn + offs_d[None, :]*stride_vk + ) + v_ptrs = v_ptr + v_offs + + #alibi slopes + if alibi_slopes_ptr is not None: + alibi_offs = off_z * stride_alibi_z + off_q_head * stride_alibi_h + alibi_slope = tl.load(alibi_slopes + alibi_offs) + else: + alibi_slope = None + + #s_dmask (return_scores) + if s_dmask_ptr is not None: + s_dmask_offs = (off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + s_dmask_ptrs = s_dmask_ptr + s_dmask_offs + else: + s_dmask_ptrs = None + + #dropout + if dropout_mask_ptr is not None: + dropout_mask_offs = (off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + dropout_mask_ptrs = dropout_mask_ptr + dropout_mask_offs + philox_ptrs = (philox_offset + + off_z * stride_sd_z + + off_q_head * stride_sd_h + + offs_m[:, None] * stride_sd_m + + offs_n[None, :] * stride_sd_n + ) + else: + dropout_mask_ptrs = None + philox_ptrs = None + + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32) + if (BLOCK_DMODEL == BLOCK_DMODEL_POW2): + q_mask = (offs_m[:, None] < seqlen_q) + else: + q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + if IS_FP8: + descale_q = tl.load(descale_q_ptr 
+ off_z * stride_descale_q_z + off_q_head) + descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) + descale_v = tl.load(descale_v_ptr + off_z * stride_descale_v_z + off_k_head) + else: + descale_q, descale_k ,descale_v = 1.0, 1.0, 1.0 + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N -seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + + #if CAUSAL, then determine masked_blocks and full blocks + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, + stride_vn, + stride_sd_n, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, + block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, + offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, + sm_scale, False, MASK_STEPS=False, ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vn + if RETURN_SCORES: + s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n + acc, l_i, m_i = _attn_fwd_inner(acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + stride_kn, stride_vn, stride_sd_n, + start_m, seqlen_k, seqlen_q, + dropout_p, + s_dmask_ptrs, dropout_mask_ptrs, philox_seed, philox_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, + offs_m, offs_n, BLOCK_M, BLOCK_N, BLOCK_DMODEL,BLOCK_DMODEL_POW2, + sm_scale, IS_CAUSAL, MASK_STEPS=True, ENABLE_DROPOUT=ENABLE_DROPOUT, + RETURN_SCORES=RETURN_SCORES, PADDED_HEAD=BLOCK_DMODEL!=BLOCK_DMODEL_POW2, + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX + ) + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). 
We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL_POW2, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + overflow_size = end_m_idx - seqlen_q + if softmax_lse_ptr is not None: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + mi_base2 = m_i * RCP_LN2 + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + + if IS_CAUSAL: + # zero out nans caused by -infs when doing causal + lse_causal_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx + softmax_lse = tl.where(lse_causal_mask, 0.0, softmax_lse) + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + offs_lse = off_z * stride_lse_z + off_q_head * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m*stride_lse_m + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + lse_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(softmax_lse_ptr + offs_lse, softmax_lse, mask=lse_mask) # the log of the normalization constant + else: + tl.store(softmax_lse_ptr + offs_lse, softmax_lse) # the log of the normalization constant + + # write back O + offs_out = (off_z * stride_oz + + off_q_head * stride_oh + + cu_seqlens_q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + out_mask = tl.full([BLOCK_M, BLOCK_DMODEL_POW2], 1, dtype=tl.int1) + if overflow_size > 0: + out_mask = out_mask & (offs_m[:, None] < seqlen_q) + if BLOCK_DMODEL != BLOCK_DMODEL_POW2: + out_mask = out_mask & (offs_d[None, :] < BLOCK_DMODEL) + op = acc.to(out_ptr.dtype.element_ty) + tl.store(out_ptr + offs_out, op, mask=out_mask) + +def _flash_attn_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + return_lse: bool, + return_softmax: bool, + max_seqlen_q: int, + max_seqlen_k: int, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + #FP8 + IS_FP8 = is_fp8(q) + FP8_MAX: tl.constexpr=torch.finfo(q.dtype).max + is_varlen = True if cu_seqlens_q is not None else False + + if IS_FP8: + o = torch.zeros_like(q, dtype=torch.float32) + else: + o = torch.zeros_like(q) + if is_varlen: + #Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] + seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, 
v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + #Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] + batch, seqlen_q, num_q_heads, head_sz = q.shape + seqlen_k = k.shape[1] + num_k_heads = k.shape[2] + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + #padding for head_dim. Power of 2 or 16 + BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) + BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) + + #softmax_lse [batch, num_q_heads, seqlen_q] + if return_lse: + if is_varlen: + softmax_lse = torch.zeros((q.shape[0], num_q_heads), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = 0, softmax_lse.stride(1), softmax_lse.stride(0) + else: + softmax_lse = torch.zeros((batch, num_q_heads, max_seqlen_q), device=q.device, dtype=torch.float32) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + else: + softmax_lse = None + + #exp_scores [batch, num_q_heads, seqlen_q, seqlen_k] + enable_dropout = dropout_p > 0.0 + if enable_dropout: + philox_seed = torch.randint(0, 0xffffff, (1,))[0].item() #No specific reason to restrict range to 0xffffff + philox_offset = torch.randint(0, 0xffffff, (1,))[0].item() #Pass in an int, not Tensor + else: + philox_seed = 0 + philox_offset = 0 + if return_softmax or enable_dropout: + s_dmask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) + dropout_mask = torch.zeros((batch, num_q_heads, max_seqlen_q, max_seqlen_k), device=q.device, dtype=torch.float32) + else: + s_dmask = None + dropout_mask = None + + + # Best config from ROCm/triton/python/perf-kernels/flash_attention.py::attn_fwd autotuning is BLOCK_M: 128, BLOCK_N: 64, waves_per_eu: 2, num_warps: 4, num_ctas: 1, num_stages: 1 + # Tuned for MI300x + config = { + 'BLOCK_M': 128, + 'BLOCK_N': 32, # BLOCK_N: 64 spills for _attn_fwd + 'waves_per_eu': 2, + 'num_warps': 4, + 'num_ctas': 1, + 'num_stages': 1, + } + + grid = lambda META:(triton.cdiv(seqlen_q, META['BLOCK_M']), num_q_heads, batch) + _attn_fwd[grid](q, + k, + v, + descale_q, + descale_k, + descale_v, + o, + alibi_slopes, + s_dmask, + dropout_mask, + softmax_lse, + *q_strides, + *k_strides, + *v_strides, + descale_q.stride(0) if descale_q is not None else 0, + descale_k.stride(0) if descale_k is not None else 0, + descale_v.stride(0) if descale_v is not None else 0, + *o_strides, + alibi_slopes.stride(0) if alibi_slopes is not None else 0, + alibi_slopes.stride(1) if alibi_slopes is not None else 0, + s_dmask.stride(0) if s_dmask is not None else 0, + s_dmask.stride(1) if s_dmask is not None else 0, + s_dmask.stride(2) if s_dmask is not None else 0, + s_dmask.stride(3) if s_dmask is not None else 0, + stride_lse_z if softmax_lse is not None else 0, + stride_lse_h if softmax_lse is not None else 0, + stride_lse_m if softmax_lse is not None else 0, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset, + SEQLEN_Q=max_seqlen_q, + SEQLEN_K=max_seqlen_k, + IS_CAUSAL=causal, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_DMODEL=head_sz, + BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, + RETURN_SCORES=return_softmax, + ENABLE_DROPOUT=enable_dropout, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + VARLEN=is_varlen, + **config + ) + + return o, 
softmax_lse, s_dmask, philox_seed, philox_offset + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +@triton.jit +def _bwd_preprocess( + o_ptr, do_ptr, # noqa: E741 + delta_ptr, + stride_o_b, stride_o_h, stride_o_m, stride_o_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + descale_do_ptr, + BLOCK_M: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) #seqlen + bid = tl.program_id(1) #batch + hid = tl.program_id(2) #head + + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # Offset O/DO by batch, head and q_start + offs = (bid * stride_o_b + + hid * stride_o_h + + q_start * stride_o_m + offs_m[:, None] * stride_o_m + + offs_k[None, :] * stride_o_k) + + # create masks + mask_m = offs_m < seqlen_q + mask = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask &= offs_k[None, :] < BLOCK_D_MODEL + + # load [BLOCK_M, BLOCK_D_MODEL_POW2] + o = tl.load(o_ptr + offs, mask=mask, other=0.0) + do = tl.load(do_ptr + offs, mask=mask, other=0.0) + + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + + offs_delta = (bid * stride_delta_b + + hid * stride_delta_h + + q_start * stride_delta_m + offs_m * stride_delta_m) + tl.store(delta_ptr + offs_delta, delta, mask=mask_m) + +@triton.jit +def _bwd_dq_inner( + dq, + q, K, V, do, m, Delta, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropout_m, stride_dropout_n, + stride_deltam, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + + # D (= delta) is pre-divided by ds_scale. 
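For reference, the loop below accumulates the standard FlashAttention dq tile update. A minimal single-tile PyTorch sketch of that math, assuming m is the saved per-row log-sum-exp (softmax_lse) and delta = rowsum(O * dO) as computed in _bwd_preprocess; the function name and shapes are illustrative only, not part of this kernel's API, and dropout, FP8 descaling, and boundary masking are omitted:

import torch

def dq_tile_reference(q, k, v, do, m, delta, sm_scale):
    # q, do: [M, D]; k, v: [N, D]; m, delta: [M]  (one M-tile against one N-tile)
    p = torch.exp(q @ k.T * sm_scale - m[:, None])   # probabilities recomputed from the saved log-sum-exp
    dp = do @ v.T                                    # dP = dO V^T
    ds = p * (dp - delta[:, None])                   # dS = P * (dP - delta)
    # dq accumulates ds @ k over all N-tiles; sm_scale is applied once in the
    # caller's epilogue (dq *= sm_scale), matching the kernel.
    return ds @ k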
+ Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + + curr_n = start_n + step_n = BLOCK_N + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < BLOCK_D_MODEL + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + philox_offs = (curr_philox_offset + + offs_m[:, None] * stride_dropout_m + + offs_n[None, :] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + #qk + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask * mask_mn + p = tl.where(mask, p, 0.0) + + #dp + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + + #ds + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + #dq + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds*scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def _bwd_dkdv_inner( + dk, dv, + Q, k, v, DO, M, D, sm_scale, + stride_q_m, stride_q_k, + stride_do_m, stride_do_k, + stride_dropout_m, stride_dropout_n, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + qT_ptrs = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] + do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + #Iterate over blocks(BLOCK_M size) of Q while calculating + #a fixed block(BLOCK_N) of dk and dv. Note, during backward + #pass P has to be recomputed. However, this kernel computes + #dV and dK, so we compute we need P^T and S^T. 
See backward pass + #equations + # + #From Flash Attention Paper: + #ForwardPass: S = QkT, P=softmax(S), O=PV + # + #BackwardPass equations + #dV = P^TdO + #dP = dOV^T + #dS = dsoftmax(dP) + #dQ = dSK + #dK = QdS^T + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + #load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = (curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + #Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + #Compute qkT + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + + #Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + #load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + #dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + #Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + #Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + #compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + #increment pointers + curr_m += step_m + qT_ptrs += step_m * stride_q_m + do_ptrs += step_m * stride_do_m + + return dk, dv + + +@triton.jit +def _bwd_dkdvdq_inner( + dk, dv, + Q, k, v, DO, DQ, M, D, sm_scale, + stride_q_m, stride_q_k, + stride_do_m, stride_do_k, + stride_dropout_m, stride_dropout_n, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + workgroup_id: tl.int32, +): + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = 
tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + + qT_ptrs_start = Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k #[BLOCK_D_MODEL_POW2, BLOCK_M] + dq_ptrs_start = DQ + offs_m[:, None] * stride_q_m + offs_k[None,:] * stride_q_k #[BLOCK_M, BLOCK_D_MODEL_POW2] + + do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None,: ] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + #Iterate over blocks(BLOCK_M size) of Q while calculating + #a fixed block(BLOCK_N) of dk and dv. Note, during backward + #pass P has to be recomputed. However, this kernel computes + #dV and dK, so we compute we need P^T and S^T. See backward pass + #equations + # + #From Flash Attention Paper: + #ForwardPass: S = QkT, P=softmax(S), O=PV + # + #BackwardPass equations + #dV = P^TdO + #dP = dOV^T + #dS = dsoftmax(dP) + #dQ = dSK + #dK = QdS^T + + # Compute a starting index and step based on workgroup_id + # Use a simple hash-like function to spread out the starting points + start_idx = (workgroup_id * 17) % num_steps # 17 is an arbitrary prime to spread indices + # Ensure step is coprime with num_steps to visit all indices exactly once + step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps + + + for iter in range(num_steps): + # Compute the permuted block index + blk_idx = (start_idx + iter * step) % num_steps + + curr_m = start_m + blk_idx * step_m + qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_q_m + do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m + + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + #load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + #dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = (curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + #Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + #Compute qkT + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + + #Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + #load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + #dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do) * descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + #Load delta + Di = 
tl.load(D + offs_m * stride_deltam, mask=mask_m) + + #Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do)) * descale_v * descale_do + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + #compute dk + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + + # We can compute the dq_partial here and do a atomic add to the correct memory location + # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before + # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) + if IS_FP8: + dq_partial = tl.dot((dsT * scale_dsT).to(k.dtype).T, k) * descale_dsT * descale_k + else: + dq_partial = tl.dot(dsT.to(k.dtype).T, k) + tl.atomic_add( + dq_ptrs, + dq_partial * sm_scale, + mask=mask_m[:, None], + sem="relaxed", + ) + + return dk, dv + + +@triton.jit +def _bwd_kernel_dkdvdq_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, dq_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + batch_idx = wid % BATCH + head_k_idx = wid // BATCH % NUM_K_HEADS + seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + #Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
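To make the block-skipping arithmetic concrete, here is a host-side Python sketch that mirrors the delta_qk / start_delta computation which follows; it is illustrative only and is not called by the kernel:

def causal_start_delta(seqlen_q, seqlen_k, BLOCK_M, BLOCK_N):
    # Mirrors the delta_qk / start_delta logic in the kernel body below.
    delta_qk = seqlen_q - seqlen_k
    if delta_qk >= 0:
        # q >= k: row m attends to column n iff m >= n + delta_qk,
        # so the diagonal block for this N-tile starts delta_qk rows down.
        return delta_qk
    # q < k: only whole BLOCK_N tiles can be skipped outright; the remainder
    # is rounded down to a BLOCK_M boundary to pick the starting M block.
    num_blocks_skip = -delta_qk // BLOCK_N
    delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk
    return delta_aligned // BLOCK_M * BLOCK_M

# Example: seqlen_q=128, seqlen_k=256, BLOCK_M=64, BLOCK_N=32
# -> delta_qk=-128, num_blocks_skip=4, delta_aligned=32, start_delta=0
assert causal_start_delta(128, 256, 64, 32) == 0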
+ delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k) + adj_v = (batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = (start_m // BLOCK_M) * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + + q_ptr_adj = q_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_q + + adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + do_ptr_adj = do_ptr + adj_do + adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + + + # when q < k, we may skip the initial masked op + # if seq_k_blk_idx < num_blocks_skip: + # num_steps = 0 + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * 
stride_descale_do_z + head_q_idx)
+ else:
+ descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0
+
+ # If start_m is negative, the current N-tile has no block on the
+ # diagonal of the causal mask, so nothing in it needs causal masking.
+ dk, dv = _bwd_dkdvdq_inner(
+ dk, dv, # output tensors
+ q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors
+ stride_q_m, stride_q_k, # strides for q
+ stride_do_m, stride_do_k, # strides for o
+ stride_dropout_m, stride_dropout_n, # strides for dropout
+ stride_delta_m,
+ dropout_p, philox_seed, batch_philox_offset, dropout_offset, #
+ seqlen_q, seqlen_k, # max sequence length for q and k
+ start_n, start_m, num_steps, # iteration numbers
+ descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user
+ MASK_BLOCK_M, BLOCK_N, # block dim
+ BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim
+ MASK=True, # causal masking
+ ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout
+ IS_FP8=IS_FP8,
+ FP8_MAX=FP8_MAX,
+ workgroup_id=seq_k_blk_idx,
+ )
+ start_m += num_steps * MASK_BLOCK_M
+ num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M)
+ end_m = start_m + num_steps * BLOCK_M
+
+ dk, dv = _bwd_dkdvdq_inner(
+ dk, dv, # output tensors
+ q_ptr_adj, k, v, do_ptr_adj, dq_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors
+ stride_q_m, stride_q_k, # strides for q
+ stride_do_m, stride_do_k, # strides for o
+ stride_dropout_m, stride_dropout_n, # strides for dropout
+ stride_delta_m,
+ dropout_p, philox_seed, batch_philox_offset, dropout_offset, #
+ seqlen_q, seqlen_k, # max sequence length for q and k
+ start_n, start_m, num_steps, # iteration numbers
+ descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user
+ BLOCK_M, BLOCK_N, # block dim
+ BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim
+ MASK=False, # no causal masking for the remaining full blocks
+ ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout
+ IS_FP8=IS_FP8,
+ FP8_MAX=FP8_MAX,
+ workgroup_id=seq_k_blk_idx,
+ )
+
+ # Write back dV and dK.
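+ # dq for this K/V tile has already been accumulated into dq_ptr via
+ # tl.atomic_add inside _bwd_dkdvdq_inner (pre-scaled by sm_scale there),
+ # so the epilogue only stores dk and dv; dk still gets the sm_scale
+ # factor below, while dv needs no scaling.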
+ offs_dkdv = (batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dkdv_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dk_ptr, dv_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dk_b, stride_dk_h, stride_dk_n, stride_dk_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + #seq block, batch, head_k + seq_k_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + #Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
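Before the block-skipping setup below, it may help to see the per-tile math that _bwd_dkdv_inner accumulates for one K/V tile against one Q tile. A rough PyTorch sketch, with dropout, FP8 descaling, and boundary masking omitted (names and shapes are illustrative, not the kernel's API; m is the saved per-row log-sum-exp and delta = rowsum(O * dO)):

import torch

def dkdv_tile_reference(q, k, v, do, m, delta, sm_scale):
    # q, do: [M, D]; k, v: [N, D]; m, delta: [M]
    pT = torch.exp(k @ q.T * sm_scale - m[None, :])  # P^T recomputed from the saved log-sum-exp
    dv = pT @ do                                     # dV = P^T dO
    dpT = v @ do.T                                   # dP^T = V dO^T
    dsT = pT * (dpT - delta[None, :])                # dS^T = P^T * (dP^T - delta)
    dk = dsT @ q                                     # dK (sm_scale is applied at write-back)
    return dk, dv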
+ delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx *BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k) + adj_v = (batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k , mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + q_ptr_adj = q_ptr + adj_q + adj_do = batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + do_ptr_adj = do_ptr + adj_do + adj_delta = batch_idx * stride_delta_b + head_q_idx * stride_delta_h + q_start * stride_delta_m + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if seq_k_blk_idx < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) + else: + 
descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0
+
+ # If start_m is negative, the current N-tile has no block on the
+ # diagonal of the causal mask, so nothing in it needs causal masking.
+ dk, dv = _bwd_dkdv_inner(
+ dk, dv, # output tensors
+ q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors
+ stride_q_m, stride_q_k, # strides for q
+ stride_do_m, stride_do_k, # strides for o
+ stride_dropout_m, stride_dropout_n, # strides for dropout
+ stride_delta_m,
+ dropout_p, philox_seed, batch_philox_offset, dropout_offset, #
+ seqlen_q, seqlen_k, # max sequence length for q and k
+ start_n, start_m, num_steps, # iteration numbers
+ descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user
+ MASK_BLOCK_M, BLOCK_N, # block dim
+ BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim
+ MASK=True, # causal masking
+ ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout
+ IS_FP8=IS_FP8,
+ FP8_MAX=FP8_MAX,
+ )
+ start_m += num_steps * MASK_BLOCK_M
+ num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M)
+ end_m = start_m + num_steps * BLOCK_M
+
+ dk, dv = _bwd_dkdv_inner(
+ dk, dv, # output tensors
+ q_ptr_adj, k, v, do_ptr_adj, m_ptr_adj, delta_ptr_adj, sm_scale, # input tensors
+ stride_q_m, stride_q_k, # strides for q
+ stride_do_m, stride_do_k, # strides for o
+ stride_dropout_m, stride_dropout_n, # strides for dropout
+ stride_delta_m,
+ dropout_p, philox_seed, batch_philox_offset, dropout_offset, #
+ seqlen_q, seqlen_k, # max sequence length for q and k
+ start_n, start_m, num_steps, # iteration numbers
+ descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user
+ BLOCK_M, BLOCK_N, # block dim
+ BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, # head dim
+ MASK=False, # no causal masking for the remaining full blocks
+ ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout
+ IS_FP8=IS_FP8,
+ FP8_MAX=FP8_MAX,
+ )
+
+ # Write back dV and dK.
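+ # dk/dv hold the accumulated contributions from every query head in this
+ # K/V head's GQA group (the head_q_idx loop above), so the gradients written
+ # back are per-K/V-head; only dk needs the sm_scale factor, dv does not.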
+ offs_dkdv = (batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + +@triton.jit +def _bwd_kernel_dq_causal( + q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dq_ptr, + m_ptr, delta_ptr, + stride_q_b, stride_q_h, stride_q_m, stride_q_k, + stride_k_b, stride_k_h, stride_k_n, stride_k_k, + stride_v_b, stride_v_h, stride_v_n, stride_v_k, + stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k, + stride_delta_b, stride_delta_h, stride_delta_m, + stride_do_b, stride_do_h, stride_do_m, stride_do_k, + stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + seq_q_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = seq_q_blk_idx * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if start_m + BLOCK_M < delta_qk: + return + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k + offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n + adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n + k_ptr_adj = k_ptr + v_ptr_adj = v_ptr + k_ptr_adj += adj_k + v_ptr_adj += adj_v + + # If MQA / GQA, set the K and V head offsets appropriately. 
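The group mapping used here can be sketched in a few lines of plain Python; the head counts are made-up example values, not something this kernel requires:

NUM_Q_HEADS, NUM_K_HEADS = 32, 8            # illustrative values
GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS     # 4 query heads per K/V head
for head_k_idx in range(NUM_K_HEADS):
    group = list(range(head_k_idx * GROUP_SIZE, (head_k_idx + 1) * GROUP_SIZE))
    # head_k_idx 0 -> Q heads [0, 1, 2, 3], head_k_idx 1 -> [4, 5, 6, 7], ...
    print(head_k_idx, group)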
+ GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for head_q_idx in range(head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + + # offset input and output tensor by batch and Q/K heads + adj_q = (batch_idx * stride_q_b + + head_q_idx * stride_q_h + + q_start * stride_q_m) + adj_do = (batch_idx * stride_do_b + + head_q_idx * stride_do_h + + q_start * stride_do_m) + adj_delta = (batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m) + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + dropout_offset = (dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h) + + q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + batch_idx * stride_descale_q_z + head_q_idx) + descale_k = tl.load(descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx) + descale_v = tl.load(descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx) + descale_do = tl.load(descale_do_ptr + batch_idx * stride_descale_do_z + head_q_idx) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + dq = _bwd_dq_inner( + dq, + q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, + stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, + stride_dropout_m, stride_dropout_n, + stride_delta_m, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, MASK_BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + dq = _bwd_dq_inner( + dq, + q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale, + stride_q_m, stride_q_k, stride_k_n, stride_k_k, stride_v_n, stride_v_k, + stride_dropout_m, stride_dropout_n, + stride_delta_m, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + # Write back dQ. 
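+ # dq was accumulated in fp32 over both the masked (diagonal) and unmasked
+ # N ranges above; the sm_scale factor from the score scaling is applied
+ # exactly once here rather than inside _bwd_dq_inner's loop.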
+ offs_dq = (batch_idx * stride_dq_b + + head_q_idx * stride_dq_h + + q_start * stride_dq_m + + offs_m[:, None] * stride_dq_m + + offs_k[None, :] * stride_dq_k) + dq *= sm_scale + tl.store(dq_ptr + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_dkdvdq_noncausal( + Q, K, V, sm_scale, DO, DK, DV, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + # workgroup id + wid = tl.program_id(0) # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # Workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + # This is in order to avoid contention for the tl.atomic_add (inside _bwd_dkdvdq_inner) that happens between workgroups that share the same batch and head_k. + bid = wid % BATCH + hkid = wid // BATCH % NUM_K_HEADS + pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk) + adj_v = (bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) + + Q_ptr = Q + adj_q + DQ_ptr = DQ + adj_q + + adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) + DO_ptr = DO + adj_do + adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + #dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + 
hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + + dk, dv = _bwd_dkdvdq_inner( + dk, dv, + Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale, + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=pid, + ) + + adj_dkdv = (bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + mask_kv &= offs_k < BLOCK_D_MODEL + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = (bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_k[None, :] * stride_kk) + adj_v = (bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_k[None, :] * stride_vk) + + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + + for hqid in range(hkid * GROUP_SIZE, hkid * 
GROUP_SIZE + GROUP_SIZE): + adj_q = (bid * stride_qb + hqid * stride_qh + q_start * stride_qm) + Q_ptr = Q + adj_q + adj_do = (bid * stride_dob + hqid * stride_doh + q_start * stride_dom) + DO_ptr = DO + adj_do + adj_delta = (bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam) + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + #dropout + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + seqlen_q, seqlen_k, + start_n, start_m, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dkdv = (bid * stride_dkb + + hkid * stride_dkh + + k_start * stride_dkn + offs_n[:, None] * stride_dkn + + offs_k[None, :] * stride_dkk) + tl.store(DV + adj_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, DO, DQ, + M, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + descale_q_ptr, descale_k_ptr, descale_v_ptr, descale_do_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + pid = tl.program_id(0) #seqlen + bid = tl.program_id(1) #batch + hkid = tl.program_id(2) #head_k + + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_m = start_m + tl.arange(0, BLOCK_M) + + #mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (BLOCK_D_MODEL != BLOCK_D_MODEL_POW2) + if PADDED_HEAD: + 
mask_k = offs_k < BLOCK_D_MODEL + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + delta_ptr = delta + adj_delta + + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = (philox_offset_base + + bid * stride_dropoutb + + hqid * stride_dropouth) + dropout_offset = ( + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) + m = m[:, None] + + #FP8 + if IS_FP8: + descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hqid) + descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid) + descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid) + descale_do = tl.load(descale_do_ptr + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M, BLOCK_N, + BLOCK_D_MODEL, BLOCK_D_MODEL_POW2, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + +def _flash_attn_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int] = 0, + philox_offset: Optional[int] = 0, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + fused: bool = False, +): + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) + else: + FP8_MAX = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None + descale_strides = (stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z) + + IS_VARLEN 
= cu_seqlens_q is not None
+
+    # get strides and shape
+    if IS_VARLEN:
+        # layout for q, k, v is thd, i.e. [total_tokens, num_heads, head_dim]
+        batch, seqlen_q, num_q_heads, head_sz = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2]
+        seqlen_k, num_k_heads = max_seqlen_k, k.shape[1]
+        q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
+        k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
+        v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
+        o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
+        dq_strides = (0, dq.stride(1), dq.stride(0), dq.stride(2))
+        dk_strides = (0, dk.stride(1), dk.stride(0), dk.stride(2))
+        dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2))
+        do_strides = (0, do.stride(1), do.stride(0), do.stride(2))
+    else:
+        # layout for q, k, v is bshd, i.e. [batch, seq_len, num_heads, head_dim]
+        batch, seqlen_q, num_q_heads, head_sz = q.shape
+        seqlen_k, num_k_heads = k.shape[1], k.shape[2]
+        q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
+        k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
+        v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
+        o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
+        dq_strides = (dq.stride(0), dq.stride(2), dq.stride(1), dq.stride(3))
+        dk_strides = (dk.stride(0), dk.stride(2), dk.stride(1), dk.stride(3))
+        dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3))
+        do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3))
+
+    # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2
+    # pad head_dim up to the next power of two, with a minimum of 16
+    BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz)
+    BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16)
+
+    # configs
+    # PRE_BLOCK, BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2
+    # BLK_SLICE_FACTOR
+    NUM_WARPS, NUM_STAGES = 4, 1
+    WAVES_PER_EU = 1
+    PRE_BLOCK = 128
+    # BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
+    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16
+    BLK_SLICE_FACTOR = 2
+
+    # init delta
+    delta = torch.zeros_like(softmax_lse)
+    if IS_VARLEN:
+        # [total_tokens, num_q_heads, seqlen_q]
+        delta_strides = (0, delta.stride(1), delta.stride(0))
+    else:
+        # [batch, num_q_heads, seqlen_q]
+        delta_strides = delta.stride()
+
+    # preprocess
+    # compute D (delta) = rowsum(dO * O); the multiplication is element-wise.
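+    # For each query row i this is Delta_i = sum_d O[i, d] * dO[i, d]. The same Delta_i is
+    # reloaded later as the "Di" / "delta_i" term of dS = P * (dP - Delta_i) inside the
+    # dk/dv and dq inner loops.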
+ pre_grid = (triton.cdiv(max_seqlen_q, PRE_BLOCK), batch, num_q_heads) + _bwd_preprocess[pre_grid]( + o, do, + delta, + *o_strides, + *delta_strides, + descale_strides[3], + cu_seqlens_q, max_seqlen_q, + descale_do, + BLOCK_M=PRE_BLOCK, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8 + ) + + #dropout_mask + use_dropout = (dropout_p > 0.0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, num_q_heads, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32) + dropout_strides = dropout_mask.stride() + else: + dropout_mask = None + dropout_strides = (0, 0, 0, 0) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, num_k_heads) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, num_k_heads) + + if fused: # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + + BLOCK_N = 128 + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } + + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * num_k_heads * num_k_pids,) + + if causal: + _bwd_kernel_dkdvdq_causal[grid_dkdvdq]( + q, k, v, sm_scale, do, dk, dv, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_dkdvdq_noncausal[grid_dkdvdq]( + q, k, v, sm_scale, do, dk, dv, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + + return delta + + # split kernels solution: one kernel computes dk, dv and the other computes dq + + if causal: + _bwd_kernel_dkdv_causal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_dq_causal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + 
*delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dk_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + *q_strides, + *k_strides, + *v_strides, + *dq_strides, + *delta_strides, + *do_strides, + *dropout_strides, + *descale_strides, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask,dropout_p, philox_seed, philox_offset, + descale_q, descale_k, descale_v, descale_do, + NUM_Q_HEADS=num_q_heads, + NUM_K_HEADS=num_k_heads, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + return delta + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + is_grad_enabled, + fused_backward, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q,k,v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + ) + + if is_grad: + ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.fused_backward = fused_backward + + + out = 
out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + _flash_attn_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + None, + None, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + fused=ctx.fused_backward, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +def flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1,-1), + alibi_slopes=None, + deterministic=True, + return_lse=False, + return_attn_probs=False, + fused_backward=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_lse=True]: (batch_size, nheads, seqlen). 
The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + torch.is_grad_enabled(), + fused_backward, + ) + + +class FlashAttnFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q,k,v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # cast input to fp8 + fp8_dtype = torch.float8_e4m3fnuz + q_fp8, descale_q = cast_to_fp8(q, fp8_dtype, "bshd") + k_fp8, descale_k = cast_to_fp8(k, fp8_dtype, "bshd") + v_fp8, descale_v = cast_to_fp8(v, fp8_dtype, "bshd") + + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=q.shape[1], + max_seqlen_k=k.shape[1], + cu_seqlens_q=None, + cu_seqlens_k=None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v + ) + + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, descale_q, descale_k, descale_v) + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + + out = out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q_fp8, dtype=torch.float32), torch.zeros_like(k_fp8, dtype=torch.float32), torch.zeros_like(v_fp8, dtype=torch.float32) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + + fp8_dtype = torch.float8_e4m3fnuz + do_padded_fp8, descale_do = cast_to_fp8(do_padded, fp8_dtype, "bshd") + _flash_attn_backward( + do_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + None, + None, + max_seqlen_q=q_fp8.shape[1], + max_seqlen_k=k_fp8.shape[1], + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_do=descale_do, + ) + #dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension + #dk = dk[..., : k_fp8.shape[-1]] + 
#dv = dv[..., : v_fp8.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None + +def flash_attn_fp8_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_lse=False, + return_attn_probs=False +): + return FlashAttnFP8Func.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + torch.is_grad_enabled() + ) + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + block_table, + is_grad_enabled, + fused_backward, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0.0, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if is_grad: + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.fused_backward = fused_backward + out = out_padded[..., :head_size_og] + + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_og = do.size(2) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + _flash_attn_backward( + do_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_k=ctx.max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + fused=ctx.fused_backward, + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1,-1), + alibi_slopes=None, + deterministic=False, + return_lse=False, + 
return_attn_probs=False,
+    block_table=None,
+    fused_backward=False,
+):
+    """dropout_p should be set to 0.0 during evaluation.
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If a row of the mask is all zero, the corresponding row of the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Defaults to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+        fused_backward: bool. Whether to fuse the dk, dv and dq computations into a single backward
+            kernel (dq is accumulated with atomic adds across workgroups).
+    Return:
+        out: (total_q, nheads, headdim).
+        softmax_lse [optional, if return_lse=True]: (nheads, total_q_seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
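+
+    Example (a minimal varlen sketch; shapes and lengths below are illustrative only):
+        import torch
+
+        # two sequences: query lengths 2 and 4, key lengths 3 and 6; 8 query heads, 2 KV heads
+        q = torch.randn(6, 8, 64, device="cuda", dtype=torch.float16)
+        k = torch.randn(9, 2, 64, device="cuda", dtype=torch.float16)
+        v = torch.randn(9, 2, 64, device="cuda", dtype=torch.float16)
+        cu_q = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")
+        cu_k = torch.tensor([0, 3, 9], dtype=torch.int32, device="cuda")
+
+        # this wrapper returns a tuple, so unpack the single output of shape (total_q, nheads, headdim)
+        out, = flash_attn_varlen_func(q, k, v, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=6, causal=True)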
+ """ + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + block_table, + torch.is_grad_enabled(), + fused_backward, + ) + + +class FlashAttnVarlenFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_softmax, + block_table, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # cast input to fp8 + fp8_dtype = torch.float8_e4m3fnuz + q_fp8, descale_q = cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) + k_fp8, descale_k = cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) + v_fp8, descale_v = cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) + + out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = _flash_attn_forward( + q_fp8, + k_fp8, + v_fp8, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + alibi_slopes=alibi_slopes, + return_lse=return_lse, + return_softmax=return_softmax and dropout_p > 0, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, descale_q, descale_k, descale_v) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + out = out_padded[..., :head_size_og] + result = [out] + if return_lse: + result.append(softmax_lse) + if return_softmax: + result.append(S_dmask) + + return tuple(result) + + @staticmethod + def backward(ctx, do, *args): + q_fp8, k_fp8, v_fp8, out, softmax_lse, cu_seqlens_q, cu_seqlens_q, descale_q, descale_k, descale_v = ctx.saved_tensors + dq, dk, dv = torch.zeros_like(q, dtype=torch.float32), torch.zeros_like(k, dtype=torch.float32), torch.zeros_like(v, dtype=torch.float32) + head_size_v_og = do.size(3) + do_padded = do + if head_size_v_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) + + fp8_dtype = torch.float8_e4m3fnuz + do_padded_fp8, descale_do = cast_varlen_to_fp8(dout_padded, fp8_dtype, "thd", cu_seqlens_q) + + _flash_attn_backward( + do_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.alibi_slopes, + ctx.causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=ctx.dropout_p, + philox_seed=ctx.philox_seed, + philox_offset=ctx.philox_offset, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_do=descale_do + ) + dq = dq[..., : q.shape[-1]] # We could have padded the head 
dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None + +def flash_attn_varlen_fp8_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_lse=False, + return_attn_probs=False, + block_table=None +): + return FlashAttnVarlenFP8Func.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_lse, + return_attn_probs, + block_table, + torch.is_grad_enabled() + ) \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py new file mode 100644 index 0000000000..3f650d288d --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_onekernel.py @@ -0,0 +1,1091 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from typing import Literal, Optional +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, get_shapes_from_layout, compute_fp8_scaling_factors, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_rdna + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + + +def get_autotune_configs(): + if False: + if is_cdna(): + # shared meta-parameters + NUM_STAGES = 1 + NUM_WARPS = 4 + WAVES_PER_EU = 2 + MATRIX_INSTR_NONKDIM = 16 + + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": 128, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({"PRE_BLOCK": 64, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({"PRE_BLOCK": 32, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({"PRE_BLOCK": 16, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + preprocess_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, 
num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + causal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 16, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 16, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + noncausal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + else: + raise ValueError("Unknown Device Type") + else: + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + assert BLOCK_N1 == BLOCK_M2 + + # configs for the kernels + preprocess_autotune_configs = [ + triton.Config({"PRE_BLOCK": PRE_BLOCK, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + preprocess_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + causal_autotune_configs = [ + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + causal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + noncausal_autotune_configs = [ + triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), + ] + noncausal_autotune_keys = [ + "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", + "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + ] + return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) + + + +(preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) = get_autotune_configs() + + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.autotune( + 
configs=preprocess_autotune_configs, + key=preprocess_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + Descale_do, + PRE_BLOCK: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_do + do_ptrs = DO + offs_do + # load + o = tl.load(out_ptrs, mask=mask_md, other=0.0) + do = tl.load(do_ptrs, mask=mask_md, other=0.0) + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensor + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. 
+ start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_nm + ) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT * sm_scale) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT * sm_scale * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + # Autoregressive masking. 
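+        # The causal mask is aligned to the bottom-right corner of the (seqlen_q, seqlen_k)
+        # matrix: query row m may see key n only when m - delta_qk >= n, with
+        # delta_qk = seqlen_q - seqlen_k. For example, with seqlen_q = 2 and seqlen_k = 5
+        # (delta_qk = -3), row 0 keeps keys 0..3 and row 1 keeps keys 0..4.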
+ if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V. + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + # Filled in by the wrapper. + start_m, start_n, end_n, num_steps, # + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = (tl.dot(q, kT) * descale_q * descale_k) + else: + qk = tl.dot(q, kT) + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk * sm_scale) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk * sm_scale - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp -delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. 
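+        # Q and DO for this BLOCK_M2 query tile were loaded once by the caller; only the
+        # K/V pointers advance here, moving to the next BLOCK_N2-wide tile along the key axis.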
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + +@triton.autotune( + configs=causal_autotune_configs, + key=causal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + offs_k = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + # align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. 
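+        # This [BLOCK_N1, HEAD_DIM] K/V tile is loaded once per (batch, k-head) program and
+        # reused for every query head of the GQA group iterated over below.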
+ k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + \ + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, 
BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}") # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, MASK_BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, 
batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + +@triton.autotune( + configs=noncausal_autotune_configs, + key=noncausal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_noncausal( + Q, K, V, sm_scale, DO, DQ, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + dropout_mask, dropout_p, philox_seed, philox_offset_base, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_EXP2: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + offs_k = tl.arange(0, HEAD_DIM) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + 
offs_kv = offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + + # K/V tensors not changed for the group + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_kv + offs_kv, mask=mask_kv, other=0.0) + v = tl.load(V + adj_kv + offs_kv, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M1, BLOCK_N1, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + None, None, None, None, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_kv = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + K += adj_kv + V += adj_kv + # If MQA / GQA, set the K and V head offsets appropriately. 
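+        # GROUP_SIZE = HQ // HK query heads share one K/V head. Illustrative
+        # (hypothetical) sizes: HQ = 32, HK = 8 gives GROUP_SIZE = 4, so the
+        # program with hkid = 2 computes dQ for query heads hqid = 8..11, all
+        # of them reading the K/V head that was offset just above.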
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, # + q, K, V, do, m, Delta_ptr, sm_scale, # + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, # + stride_dropoutm, stride_dropoutn, # + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2, BLOCK_N2, # + HEAD_DIM, ACTUAL_HEAD_DIM, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + start_m, start_n, end_n, num_steps, # + None, None, None, None, + MASK=False, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_EXP2=USE_EXP2, + IS_FP8=False, + FP8_MAX=None, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_oneKernel_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool, +): + # debug + DEBUG_TRITON: bool = False + DEBUG_TRITON_DETAIL: bool = False + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ + get_shapes_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_vb, stride_vh, stride_vn, stride_vk = v_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, _, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + + # get closest 
power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + + # init delta + delta = torch.empty_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = lambda META: (triton.cdiv(max_seqlen_q_final, META['PRE_BLOCK']), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + 0, + cu_seqlens_q, max_seqlen_q_final, + None, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN, + IS_FP8=False + ) + + # dropout mask tensor for debugging. We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + seqlen = max(max_seqlen_q_final, max_seqlen_k_final) + grid = lambda META: ((seqlen + META['BLOCK_N1'] - 1) // META['BLOCK_N1'], batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"bwd_kernel: grid = {grid}" ) # noqa: E701 + bwd_kernel_causal[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_noncausal[grid]( + q, k, v, sm_scale, do, dq, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_EXP2=use_exp2, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return delta \ No newline at end of file diff --git 
a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py new file mode 100644 index 0000000000..c1e2ff5985 --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -0,0 +1,1354 @@ +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from typing import Literal, Optional +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ + get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 + +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) + +# This function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q), same as softmax_lse defined at +# fwd_prefill.py line 607 +@triton.jit +def _bwd_preprocess( + O, DO, # noqa: E741 + Delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q, + Descale_do, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + q_start = 0 + seqlen_q = max_seqlen_q + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, HEAD_DIM) + # Offset O/DO by batch, head and q_start + O += bid * stride_ob + hid * stride_oh + q_start * stride_om # noqa: E741 + DO += bid * stride_ob + hid * stride_oh + q_start * stride_om + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_md &= offs_k[None, :] < ACTUAL_HEAD_DIM + # compute pointers + offs_do = offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok + out_ptrs = O + offs_do + do_ptrs = DO + offs_do + # load + o = tl.load(out_ptrs, mask=mask_md, other=0.0) + do = tl.load(do_ptrs, mask=mask_md, other=0.0) + # compute and write-back to delta + if IS_FP8: + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hid) + + # NOTE: do is in the fp8 range and o is not in fp8 + delta = tl.sum(o.to(tl.float32) * (do.to(tl.float32) * descale_do), axis=1) + else: + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + delta_offset = Delta + bid * stride_deltab + hid * stride_deltah + q_start * stride_deltam + tl.store(delta_offset + offs_m * stride_deltam, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. 
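+# Per tile the accumulators follow the standard attention backward, with the
+# softmax P recomputed from the saved per-row log-sum-exp M rather than stored:
+#   dV += P^T @ dO
+#   dP  = dO @ V^T
+#   dS  = P * (dP - Delta),  Delta_i = sum_d O[i, d] * dO[i, d]  (see _bwd_preprocess)
+#   dK += dS^T @ Q,          dQ += dS @ K   (dQ side lives in _bwd_dq_inner)
+# sm_scale is folded in once at write-back (dk *= sm_scale, dq *= sm_scale).
+# A rough single-tile reference in plain PyTorch, ignoring dropout, fp8 and
+# masking (tensor names are illustrative only):
+#   p  = torch.exp(q @ k.T * sm_scale - m[:, None])
+#   dv = p.T @ do
+#   dp = do @ v.T
+#   ds = p * (dp - delta[:, None])
+#   dk = ds.T @ q   # still needs the final *= sm_scale
+#   dq = ds @ k     # still needs the final *= sm_scale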
+@triton.jit +def _bwd_dkdv_inner( + dk, dv, # output + Q, k, v, DO, M, D, sm_scale, # input tensors + stride_qm, stride_qk, + stride_dom, stride_dok, + stride_dropoutm, stride_dropoutn, + stride_deltam, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM: tl.constexpr, # + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_n + (0, 127) + offs_k = tl.arange(0, HEAD_DIM) + # mask to make sure not OOB of seqlen_k + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM + mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[None, :] * stride_dropoutm + \ + offs_n[:, None] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_nm + ) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall.
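+        # m is the per-row log-sum-exp saved by the forward pass, so the
+        # softmax is recomputed without another row reduction:
+        #   pT = exp(qkT * sm_scale - m) = exp2((qkT * sm_scale - m) * RCP_LN2)
+        # the USE_EXP2 path uses the second form to hit the faster hardware exp2.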
+ m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + if IS_FP8: + qkT = (tl.dot(k, qT) * descale_q * descale_k) + else: + qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) + # TODO: remove the scaling of m later when we removed re-scaling in fwd + if USE_EXP2: + pT = tl.math.exp2(qkT_scaled * RCP_LN2 - m[None, :] * RCP_LN2) + else: + pT = tl.math.exp(qkT_scaled - m[None, :]) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print(f"qkT after causal: {qkT.shape}\n", tl.where(causal_mask, qkT * sm_scale, 0.0)) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + if IS_FP8: + scale_p_dropout, descale_p_dropout = compute_fp8_scaling_factors(pT_dropout, FP8_MAX) + dv += (tl.dot((pT_dropout * scale_p_dropout).to(do.type.element_ty), do)* descale_p_dropout * descale_do) + else: + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + if IS_FP8: + scale_pT, descale_pT = compute_fp8_scaling_factors(pT, FP8_MAX) + dv += (tl.dot((pT * scale_pT).to(do.type.element_ty), do) * descale_pT * descale_do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + # Compute dP and dS. + if IS_FP8: + dpT = (tl.dot(v, tl.trans(do)) * descale_v * descale_do) + else: + dpT = tl.dot(v, tl.trans(do)) + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + if IS_FP8: + scale_dsT, descale_dsT = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += (tl.dot((dsT * scale_dsT).to(qT.type.element_ty), tl.trans(qT)) * descale_dsT * descale_q) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. 
+ curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# grid = (max_seqlen_k // BLOCK_N, batch, nheads_q) +@triton.jit +def _bwd_kernel_dkdv_causal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. 
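+    # Worked example (sizes purely illustrative): BLOCK_M = 32, BLOCK_N = 128,
+    # seqlen_q = 512, seqlen_k = 384 gives delta_qk = 128 >= 0, so the N-tile at
+    # start_n can begin its M loop at start_m = start_n + 128; every earlier
+    # query row is removed by the bottom-right-aligned causal mask. When
+    # delta_qk < 0 the skip is instead derived from num_blocks_skip and
+    # re-aligned to a BLOCK_M boundary below.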
+ delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") + if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") + # q > k: directly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, others need to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: print(f"q >= k: start_delta = delta_qk aligned to BLOCK_M = {start_delta_q_gt_k}") + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: print(f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}") + # align the delta_qk + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + + GROUP_SIZE = HQ // HK + # K/V tensors not changed for the group + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately.
+ for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + if DEBUG_TRITON: print(f"residue_m = {residue_m}") + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + if DEBUG_TRITON: print(f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}") + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + MASK_BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + if DEBUG_TRITON: print(f"start_m after Masked step: {start_m}; num_steps: {num_steps}") # noqa: E701 + if DEBUG_TRITON: print(f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: 
{num_steps}") # noqa: E701 + if DEBUG_TRITON: print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. + adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, K, V, do, m, Delta, sm_scale, # input + # shared by Q/K/V. + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, # stride for dropout + stride_deltam, + seqlen_q, seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, # + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + # Filled in by the wrapper. + start_m, start_n, end_n, num_steps, # + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 + if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = curr_philox_offset + \ + offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + if tl_DROPOUT_USE_PYTORCH: + dropout_offs = offs_m[:, None] * stride_dropoutm + \ + offs_n[None, :] * stride_dropoutn + dropout_mask = tl.load( + curr_dropout_offset + dropout_offs, + mask=mask_mn) + else: + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = (tl.dot(q, kT) * descale_q * descale_k) + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + if USE_EXP2: + p = tl.math.exp2(qk_scaled * RCP_LN2 - m * RCP_LN2) + else: + p = tl.math.exp(qk_scaled - m) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + if IS_FP8: + dp = (tl.dot(do, vT) * descale_do * descale_v) + else: + dp = tl.dot(do, vT) + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp -delta_i) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + scale_ds, descale_ds = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += (tl.dot((ds * scale_ds).to(kT.type.element_ty), tl.trans(kT)) * descale_ds * descale_k) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +# grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nheads_q) +@triton.jit +def _bwd_kernel_dq_causal( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + # Figure out causal starting block since we have seqlen_q <=> seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + # DQ tiles on M dim and iterate on N dim, so we there could be some tiles we + # can simply skip and we need to adjust starting position. + start_m = pid * BLOCK_M + # seqlen_q > seqlen_k, no need to process these tile for dq + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: print(f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M}") # noqa: E701 + if start_m + BLOCK_M < delta_qk: + if DEBUG_TRITON: print(f"start_m + BLOCK_M = {start_m} + {BLOCK_M} = {start_m + BLOCK_M} < delta_qk of {delta_qk}") # noqa: E701 + return + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + # If MQA / GQA, set the K and V head offsets appropriately. 
+ GROUP_SIZE = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N) + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + if DEBUG_TRITON: print(f"pid: {pid}; end_n: {end_n}, start_m: {start_m}") # noqa: E701 + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _bwd_dq_inner, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. 
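+        # Two phases per M-tile: the (at most BLOCK_M wide) band of key
+        # positions straddling the causal diagonal is handled first with
+        # MASK=True in MASK_BLOCK_N steps, then the remaining fully-unmasked
+        # keys below end_n are swept with MASK=False in full BLOCK_N steps.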
+ if DEBUG_TRITON: print(f"Masked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, MASK_BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=True, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N + num_steps = tl.cdiv(end_n, BLOCK_N) + start_n = max(end_n - num_steps * BLOCK_N, 0) + if DEBUG_TRITON: print(f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}") # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +@triton.jit +def _bwd_kernel_dkdv_noncausal( + Q, K, V, sm_scale, DO, DK, DV, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, # 32 + BLOCK_N: tl.constexpr, # 128 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, HEAD_DIM], 
dtype=tl.float32) + dv = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + + start_n = pid * BLOCK_N + + offs_k = tl.arange(0, HEAD_DIM) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_kv &= mask_k[None, :] + + GROUP_SIZE = HQ // HK + # K/V tensors not changed for the group + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + M_ptr = M + adj_delta + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = Dropout_mask + bid * stride_dropoutb + \ + hqid * stride_dropouth + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M) + dk, dv = _bwd_dkdv_inner( + dk, dv, # output tensors + Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale, # input tensors + stride_qm, stride_qk, # strides for q + stride_dom, stride_dok, # strides for o + stride_dropoutm, stride_dropoutn, # strides for dropout + stride_deltam, + BLOCK_M, BLOCK_N, # block dim + HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + dropout_p, philox_seed, batch_philox_offset, dropout_offset, # + alibi_slope, + seqlen_q, seqlen_k, # max sequence length for q and k + start_n, start_m, num_steps, # iteration numbers + descale_q, descale_k, descale_v, descale_do, # fp8 descale factors from user + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV and dK. 
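+    # dk holds the un-scaled dS^T @ Q sum over the whole GQA group, so the
+    # softmax scale is applied once here rather than per inner-loop iteration.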
+ adj_dkdv = bid * stride_dkb + hkid * stride_kh + k_start * stride_dkn + offs_dkdv = offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk + tl.store(DV + adj_dkdv + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(DK + adj_dkdv + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_dq_noncausal( + Q, K, V, sm_scale, DO, DQ, + M, Delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + HQ, HK, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + Dropout_mask, dropout_p, philox_seed, philox_offset_base, + Alibi_slopes, + Descale_q, Descale_k, Descale_v, Descale_do, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + pid = tl.program_id(0) + bid = tl.program_id(1) + hkid = tl.program_id(2) + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + start_m = pid * BLOCK_M + + offs_k = tl.arange(0, HEAD_DIM) + offs_m = start_m + tl.arange(0, BLOCK_M) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + if PADDED_HEAD: + mask_k = offs_k < ACTUAL_HEAD_DIM + mask_q &= mask_k[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + K += adj_k + V += adj_v + # If MQA / GQA, set the K and V head offsets appropriately. 
+ GROUP_SIZE = HQ // HK + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = \ + bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam + Delta_ptr = Delta + adj_delta + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + \ + bid * stride_dropoutb + \ + hqid * stride_dropouth + dropout_offset = \ + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + m = tl.load(M + adj_delta + offs_m * stride_deltam, + mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + descale_do = tl.load(Descale_do + bid * stride_descale_do_z + hqid) + else: + descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N) + dq = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, K, V, do, m, Delta_ptr, sm_scale, + stride_qm, stride_qk, stride_kn, stride_kk, stride_vn, stride_vk, + stride_dropoutm, stride_dropoutn, + stride_deltam, + seqlen_q, seqlen_k, + BLOCK_M, BLOCK_N, + HEAD_DIM, ACTUAL_HEAD_DIM, + dropout_p, philox_seed, batch_philox_offset, dropout_offset, + alibi_slope, + start_m, start_n, end_n, num_steps, + descale_q, descale_k, descale_v, descale_do, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
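# Illustrative sketch (editor's note, not part of the patch): the write-back below applies
# sm_scale once to the accumulated dQ. Ignoring dropout, the stored quantity matches the
# reference gradient for one (batch, query-head), with assumed dense float32 tensors:
#
#     p     = torch.exp(q @ k.T * sm_scale - lse[:, None])    # recomputed softmax
#     dp    = do @ v.T
#     delta = (p * dp).sum(-1, keepdim=True)                  # mathematically equals (o * do).sum(-1, keepdim=True)
#     dq    = (p * (dp - delta)) @ k * sm_scale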
+ adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def attention_prefill_backward_triton_split_impl( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], + descale_do: Optional[torch.Tensor], + descale_dq: Optional[torch.Tensor], + descale_dk: Optional[torch.Tensor], + descale_dv: Optional[torch.Tensor], +): + # debug + DEBUG_TRITON: bool = False + DEBUG_TRITON_DETAIL: bool = False + + # fp8 + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + # assert that the main inputs are fp8 + assert is_fp8(do) and is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: do.dtype={do.dtype}, q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." + assert descale_dq is not None, f"descale_dq is None. In fp8, you need to pass a tensor for descale_dq along with a tensor dq." + assert descale_dk is not None, f"descale_dk is None. In fp8, you need to pass a tensor for descale_dk along with a tensor dk." + assert descale_dv is not None, f"descale_dv is None. In fp8, you need to pass a tensor for descale_dv along with a tensor dv." 
+ else: + FP8_OUTPUT = False + + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = stride_descale_do_z = None + + + # get strides and shape + batch, nheads_q, nheads_k, head_size, max_seqlen_q_final, max_seqlen_k_final = \ + get_shapes_from_layout( + q, k, layout, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k + ) + q_strides, k_strides, v_strides, o_strides = \ + get_strides_from_layout(q, k, v, o, layout) + stride_qb, stride_qh, stride_qm, stride_qk = q_strides + stride_kb, stride_kh, stride_kn, stride_kk = k_strides + stride_vb, stride_vh, stride_vn, stride_vk = v_strides + stride_ob, stride_oh, stride_om, stride_ok = o_strides + dq_strides, dk_strides, dv_strides, do_strides = \ + get_strides_from_layout(dq, dk, dv, do, layout) + stride_dqb, stride_dqh, stride_dqm, stride_dqk = dq_strides + stride_dkb, stride_dkh, stride_dkn, stride_dkk = dk_strides + stride_dvb, stride_dvh, stride_dvn, stride_dvk = dv_strides + stride_dob, stride_doh, stride_dom, stride_dok = do_strides + IS_VARLEN = layout == "thd" + use_dropout = (dropout_p > 0.0) + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + + # get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 32) # NOTE: the causal path expects a min of 32. It will cause a compiler assert. + HEAD_DIM = padded_d_model + ACTUAL_HEAD_DIM = head_size + # meta-parameters + # TODO: fix num_stages later + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + PRE_BLOCK = 128 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + + # init delta + delta = torch.zeros_like(softmax_lse) + if IS_VARLEN: + stride_deltab = 0 + stride_deltam, stride_deltah = delta.stride() + else: + stride_deltab, stride_deltah, stride_deltam = delta.stride() + pre_grid = (triton.cdiv(max_seqlen_q_final, PRE_BLOCK), batch, nheads_q) + _bwd_preprocess[pre_grid]( + o, do, + delta, + stride_ob, stride_oh, stride_om, stride_ok, + stride_deltab, stride_deltah, stride_deltam, + stride_descale_do_z, + cu_seqlens_q, max_seqlen_q_final, + descale_do, + BLOCK_M=PRE_BLOCK, + HEAD_DIM=HEAD_DIM, + ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8 + ) + + if DEBUG: + print("delta:", delta, delta.shape) + + # dropout mask tensor for debugging. 
We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + (0, 0 , 0 , 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + device=q.device, + dtype=torch.float32 + ) + + if DROPOUT_USE_PYTORCH: + if not IS_VARLEN: + dropout_mask = create_dropout_mask( + dropout_p, + (batch, nheads_q, max_seqlen_q_final, max_seqlen_k_final), + seed = philox_seed + ) + else: + dropout_mask = create_dropout_mask_varlen( + dropout_p, batch, nheads_q, + cu_seqlens_q, cu_seqlens_k, philox_seed + ) + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = \ + dropout_mask.stride() + + grid_dkdv = ((max_seqlen_k_final + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q_final + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + if causal: + if DEBUG_TRITON: print(f"_bwd_kernel_dkdv: grid = {grid_dkdv}, block_size = ({BLOCK_M1, BLOCK_N1})", ) # noqa: E701 + _bwd_kernel_dkdv_causal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + if DEBUG_TRITON: print(f"\n_bwd_kernel_dq: grid = {grid_dq}, block_size = ({BLOCK_M2, BLOCK_N2})", ) # noqa: E701 + _bwd_kernel_dq_causal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + _bwd_kernel_dkdv_noncausal[grid_dkdv]( + q, k, v, sm_scale, do, dk, dv, + softmax_lse, delta, + 
stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dkb, stride_dkh, stride_dkn, stride_dkk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + _bwd_kernel_dq_noncausal[grid_dq]( + q, k, v, sm_scale, do, dq, + softmax_lse, delta, + stride_qb, stride_qh, stride_qm, stride_qk, + stride_kb, stride_kh, stride_kn, stride_kk, + stride_vb, stride_vh, stride_vn, stride_vk, + stride_dqb, stride_dqh, stride_dqm, stride_dqk, + stride_deltab, stride_deltah, stride_deltam, + stride_dob, stride_doh, stride_dom, stride_dok, + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn, + stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_do_z, + stride_az, stride_ah, + nheads_q, nheads_k, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q_final, max_seqlen_k_final, + dropout_mask, dropout_p, philox_seed, philox_offset, + alibi_slopes, + descale_q, descale_k, descale_v, descale_do, + BLOCK_M2, BLOCK_N2, BLK_SLICE_FACTOR, + HEAD_DIM, ACTUAL_HEAD_DIM, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_OUTPUT=FP8_OUTPUT, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu = WAVES_PER_EU, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + return delta diff --git a/flash_attn/flash_attn_triton_amd/bwd_ref.py b/flash_attn/flash_attn_triton_amd/bwd_ref.py index 7ea7c32bf7..90a98ce4fc 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/bwd_ref.py @@ -1,11 +1,14 @@ import torch import math -from .utils import DEBUG +from typing import Literal, Optional +from .utils import DEBUG, compute_alibi_tensor_ref + +DEBUG_CORE = False def attention_backward_core_ref_impl( - do, q, k, v, o, softmax_lse, sm_scale, causal, use_exp2 + do, q, k, v, o, softmax_lse, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ): - if DEBUG: + if DEBUG_CORE: print() print("attention_backward_core_ref_impl") print("do:", do, do.shape) @@ -16,6 +19,9 @@ def attention_backward_core_ref_impl( print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) # cast to float32 @@ -28,15 +34,27 @@ def attention_backward_core_ref_impl( # recompute attention_scores. Make sure it matches the forward impl. i.e. 
It use float32 - attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) - if DEBUG: + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) # scale scores attention_scaled_scores = sm_scale * attention_scores - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + if alibi_slopes is not None: + L_q, L_k = q.shape[1], k.shape[1] + if DEBUG_CORE: + print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) + alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) + alibi_bias = alibi_bias.reshape(-1, L_q, L_k) + if True: + print("alibi_bias:", alibi_bias, alibi_bias.shape) + attention_scaled_scores = attention_scaled_scores + alibi_bias + if DEBUG_CORE: + print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) + # Apply causal mask if necessary if causal: L_q, L_k = q.shape[1], k.shape[1] @@ -44,13 +62,13 @@ def attention_backward_core_ref_impl( col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) col_offset = L_q-L_k causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG: + if DEBUG_CORE: print("causal_mask:", causal_mask) # set -inf to places the causal mask is false attention_scaled_scores = attention_scaled_scores.masked_fill( torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') ) - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) # compute probabilities using softmax_lse @@ -63,58 +81,79 @@ def attention_backward_core_ref_impl( else: softmax_lse_3d = softmax_lse.unsqueeze(-1) p = torch.exp(attention_scaled_scores - softmax_lse_3d) - - if DEBUG: + if DEBUG_CORE: print("softmax_lse_3d:", softmax_lse_3d, softmax_lse_3d.shape) print("p:", p, p.shape) - # compute gradient wrt v - dv = torch.matmul(p.transpose(-2, -1), do.to(torch.float32)) - if DEBUG: - print("dv:", dv, dv.shape) - # compute dp - dp = torch.matmul(do, v.transpose(-2, -1)) - if DEBUG: - print("dp:", dp, dp.shape) - - # calculate ds using dp - if True: - delta = torch.sum(o * do, axis=-1).to(torch.float32) # what OAI kernel uses - delta_3d = delta.unsqueeze(-1) + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + + p_drop = torch.where(dropout_mask, p, torch.zeros_like(p)) + p_drop_scaled = p_drop * dropout_scale + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("p_drop:", p_drop, p_drop.shape) + print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape) + + # compute dv + dv = torch.matmul(p_drop_scaled.transpose(-2, -1), do) + if DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp_dropout = torch.matmul(do, v.transpose(-2, -1)) + dp = torch.where(dropout_mask, dp_dropout , torch.zeros_like(dp_dropout)) * dropout_scale + if DEBUG_CORE: + print("dp_dropout:", dp_dropout, dp_dropout.shape) + print("dp:", dp, dp.shape) else: - delta = torch.sum(p * dp, axis=-1) # what the math says you should use - delta_3d = delta.unsqueeze(-1) - if DEBUG: - print("delta_3d:", delta_3d, delta_3d.shape) - ds = (p * (dp - delta_3d)) * sm_scale + # compute dv + dv = torch.matmul(p.transpose(-2, -1), do) + if 
DEBUG_CORE: + print("dv:", dv, dv.shape) + + # compute dp + dp = torch.matmul(do, v.transpose(-2, -1)) + if DEBUG_CORE: + print("dp:", dp, dp.shape) + + # calculate ds + if False: + delta = torch.sum(o * do, axis=-1).unsqueeze(-1) + else: + delta = torch.sum(p * dp, axis=-1).unsqueeze(-1) if DEBUG: + print("delta:", delta, delta.shape) + dscores_scaled = p * (dp - delta) + ds = dscores_scaled * sm_scale + if DEBUG_CORE: + print("dscores_scaled:", dscores_scaled, dscores_scaled.shape) print("ds:", ds, ds.shape) - - # compute gradient wrt k - dk = torch.matmul(ds.transpose(-2, -1), q.to(torch.float32)) - if DEBUG: + # compute gradient wrt k & q + dk = torch.matmul(ds.transpose(-2, -1), q) + dq = torch.matmul(ds, k) + if DEBUG_CORE: print("dk:", dk, dk.shape) - - # compute gradient wrt q - dq = torch.matmul(ds, k.to(torch.float32)) - if DEBUG: print("dq:", dq, dq.shape) # cast back to original dtype dq = dq.to(torch.float16) dk = dk.to(torch.float16) dv = dv.to(torch.float16) - # remove d dim with size 1 - delta = delta_3d.squeeze(-1) + delta = delta.squeeze(-1) - if DEBUG: + if DEBUG_CORE: print("attention_backward_core_ref_impl output") - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) print("delta:", delta, delta.shape) + print("dv:", dv, dv.shape) + print("dk:", dk, dk.shape) + print("dq:", dq, dq.shape) return dq, dk, dv, delta @@ -132,6 +171,10 @@ def attention_varlen_backward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ): # Ensure the layout is 'thd' @@ -139,8 +182,12 @@ def attention_varlen_backward_pytorch_ref_impl( raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") batch_size = cu_seqlens_q.shape[0] - 1 - num_heads = q.shape[1] - head_dim = q.shape[2] + nheads_q, head_dim = q.shape[1], q.shape[2] + nheads_k = k.shape[1] + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") # Pre-allocate outputs total_L_q = q.shape[0] @@ -149,8 +196,8 @@ def attention_varlen_backward_pytorch_ref_impl( dq = torch.zeros_like(q) dk = torch.zeros_like(k) dv = torch.zeros_like(v) - # delta has the same shape as softmax_lse: [total_L_q, num_heads] - delta = torch.zeros((total_L_q, num_heads), dtype=torch.float32, device=o.device) + # delta has the same shape as softmax_lse: [total_L_q, nheads_q] + delta = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=o.device) for i in range(batch_size): # Get the start and end indices for the current sequence @@ -160,22 +207,41 @@ def attention_varlen_backward_pytorch_ref_impl( end_k = cu_seqlens_k[i + 1].item() # Extract q_i, k_i, v_i, do_i, o_i, softmax_lse_i - q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - do_i = do[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - o_i = o[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - # softmax_lse has shape [total_L_q, num_heads] - softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, num_heads] - softmax_lse_i = softmax_lse_i.transpose(0, 1) # [num_heads, L_q_i] - - # Permute to [num_heads, L_q_i, head_dim] + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + do_i = do[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + o_i 
= o[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse[start_q:end_q, :] # [L_q_i, nheads_q] + + if group_size != 1: + # MQA or GQA case + # Reshape tensors to include group dimension + q_i = q_i.view(q_i.shape[0], nheads_k, group_size, head_dim) + do_i = do_i.view(do_i.shape[0], nheads_k, group_size, head_dim) + o_i = o_i.view(o_i.shape[0], nheads_k, group_size, head_dim) + softmax_lse_i = softmax_lse_i.view(softmax_lse_i.shape[0], nheads_k, group_size) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(2).expand(-1, -1, group_size, -1) + v_i = v_i.unsqueeze(2).expand(-1, -1, group_size, -1) + # Flatten the nheads_k and group_size dimensions + q_i = q_i.reshape(q_i.shape[0], nheads_k * group_size, head_dim) + do_i = do_i.reshape(do_i.shape[0], nheads_k * group_size, head_dim) + o_i = o_i.reshape(o_i.shape[0], nheads_k * group_size, head_dim) + softmax_lse_i = softmax_lse_i.reshape(softmax_lse_i.shape[0], nheads_k * group_size) + k_i = k_i.reshape(k_i.shape[0], nheads_k * group_size, head_dim) + v_i = v_i.reshape(v_i.shape[0], nheads_k * group_size, head_dim) + # Permute to [nheads_total, L, head_dim] q_i = q_i.permute(1, 0, 2) k_i = k_i.permute(1, 0, 2) v_i = v_i.permute(1, 0, 2) do_i = do_i.permute(1, 0, 2) o_i = o_i.permute(1, 0, 2) - # softmax_lse_i is already in [num_heads, L_q_i] + softmax_lse_i = softmax_lse_i.transpose(0, 1) + if alibi_slopes is not None: + alibi_slopes_i = alibi_slopes[i] + else: + alibi_slopes_i = None # Call the core backward function for this sequence dq_i, dk_i, dv_i, delta_i = attention_backward_core_ref_impl( @@ -187,20 +253,39 @@ def attention_varlen_backward_pytorch_ref_impl( softmax_lse_i, sm_scale, causal, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes_i, use_exp2 ) # Convert back to 'thd' layout - dq_i = dq_i.permute(1, 0, 2) # [L_q_i, num_heads, head_dim] - dk_i = dk_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] - dv_i = dv_i.permute(1, 0, 2) # [L_k_i, num_heads, head_dim] + dq_i = dq_i.permute(1, 0, 2) # [L_q_i, nheads_total, head_dim] + dk_i = dk_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + dv_i = dv_i.permute(1, 0, 2) # [L_k_i, nheads_total, head_dim] + delta_i = delta_i.transpose(1, 0) # [L_q_i, nheads_total] + + if group_size != 1: + # Reshape dq_i and delta_i back to original shape + dq_i = dq_i.view(dq_i.shape[0], nheads_k, group_size, head_dim) + delta_i = delta_i.view(delta_i.shape[0], nheads_k, group_size) + # Sum dk_i and dv_i over group dimension + dk_i = dk_i.view(dk_i.shape[0], nheads_k, group_size, head_dim) + dv_i = dv_i.view(dv_i.shape[0], nheads_k, group_size, head_dim) + dk_i = dk_i.sum(dim=2) + dv_i = dv_i.sum(dim=2) + # Reshape dq_i back to [L_q_i, nheads_q, head_dim] + dq_i = dq_i.reshape(dq_i.shape[0], nheads_q, head_dim) + delta_i = delta_i.reshape(delta_i.shape[0], nheads_q) + else: + # No need to reshape + pass # Place outputs in pre-allocated tensors dq[start_q:end_q, :, :] = dq_i dk[start_k:end_k, :, :] += dk_i # Accumulate gradients for shared keys dv[start_k:end_k, :, :] += dv_i # Accumulate gradients for shared values - # delta_i has shape [num_heads, L_q_i] - delta_i = delta_i.transpose(1, 0) # [L_q_i, num_heads] delta[start_q:end_q, :] = delta_i return dq, dk, dv, delta @@ -215,6 +300,10 @@ def attention_vanilla_backward_pytorch_ref_impl( sm_scale, causal, layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ): if layout == "bshd": @@ -231,18 +320,42 @@ def attention_vanilla_backward_pytorch_ref_impl( else: raise 
ValueError(f"Unknown layout {layout}") - # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format - batch_size, num_heads, seq_len_q, head_dim = q.shape - seq_len_k = k.shape[2] - - # Merge batch and heads dimensions - do = do.reshape(batch_size * num_heads, seq_len_q, head_dim) - q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) - k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) - v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) - softmax_lse = softmax_lse.reshape(batch_size * num_heads, seq_len_q) - o = o.reshape(batch_size * num_heads, seq_len_q, head_dim) - + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape do, q, o to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + do = do.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape softmax_lse to [batch_size, nheads_k, group_size, seq_len_q] + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) # [batch_size, nheads_k, group_size, seq_len_k, head_dim] + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + do = do.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_k * group_size, seq_len_q) + else: + # Standard case + do = do.reshape(batch_size * nheads_q, seq_len_q, head_dim) + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) + o = o.reshape(batch_size * nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size * nheads_q, seq_len_q) + + # Call the core backward function dq, dk, dv, delta = attention_backward_core_ref_impl( do, q, @@ -252,14 +365,32 @@ def attention_vanilla_backward_pytorch_ref_impl( softmax_lse, sm_scale, causal, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2 ) - # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] - dq = dq.reshape(batch_size, num_heads, seq_len_q, head_dim) - dk = dk.reshape(batch_size, num_heads, seq_len_k, head_dim) - dv = dv.reshape(batch_size, num_heads, seq_len_k, head_dim) - delta = delta.reshape(batch_size, num_heads, seq_len_q) + if group_size != 1: + # Reshape dq back to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Reshape delta back to [batch_size, nheads_k, group_size, seq_len_q] + delta = delta.reshape(batch_size, nheads_k, group_size, seq_len_q) + # Sum dk and dv over group_size dimension, since k and v are shared across groups + dk = dk.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dk = 
dk.sum(dim=2) # Sum over group_size dimension + dv = dv.reshape(batch_size, nheads_k, group_size, seq_len_k, head_dim) + dv = dv.sum(dim=2) + # Reshape dq to [batch_size, nheads_q, seq_len_q, head_dim] + dq = dq.reshape(batch_size, nheads_k * group_size, seq_len_q, head_dim) + delta = delta.reshape(batch_size, nheads_k * group_size, seq_len_q) + else: + # Standard case + dq = dq.reshape(batch_size, nheads_q, seq_len_q, head_dim) + dk = dk.reshape(batch_size, nheads_k, seq_len_k, head_dim) + dv = dv.reshape(batch_size, nheads_k, seq_len_k, head_dim) + delta = delta.reshape(batch_size, nheads_q, seq_len_q) # Go back to original layout if layout == "bshd": @@ -276,25 +407,31 @@ def attention_vanilla_backward_pytorch_ref_impl( return dq, dk, dv, delta - def attention_backward_pytorch_ref_impl( - do, - q, - k, - v, - o, - softmax_lse, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - use_exp2 + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool ): if layout == "thd": - dq, dk, dv, delta = attention_varlen_backward_pytorch_ref_impl( + dq_ref, dk_ref, dv_ref, delta = attention_varlen_backward_pytorch_ref_impl( do, q, k, @@ -308,10 +445,14 @@ def attention_backward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) else: - dq, dk, dv, delta = attention_vanilla_backward_pytorch_ref_impl( + dq_ref, dk_ref, dv_ref, delta = attention_vanilla_backward_pytorch_ref_impl( do, q, k, @@ -321,8 +462,17 @@ def attention_backward_pytorch_ref_impl( sm_scale, causal, layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) - return dq, dk, dv, delta + # copy into output tensor + dv.copy_(dv_ref.to(dv.dtype)) + dk.copy_(dk_ref.to(dk.dtype)) + dq.copy_(dq_ref.to(dq.dtype)) + + return delta diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py new file mode 100644 index 0000000000..df79c7926b --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/fp8.py @@ -0,0 +1,716 @@ +from typing import Optional, Sequence, Tuple, Union +import torch +import torch.nn as nn +from .utils import cast_to_fp8, is_fp8 +from . 
import interface_fa as flash_attn_gpu + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +class FlashAttnFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + alibi_slopes, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty + + # check output type + assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) + dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") + dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None + dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None + dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dq, + dk, + dv, + ctx.alibi_slopes, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, # gen_ + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +def flash_attn_fp8_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None +): + return FlashAttnFP8Func.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + descale_q, + descale_k, + descale_v, + descale_do + ) + +class FlashAttnVarlenFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + block_table, + is_grad_enabled, + descale_q: 
Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + cu_seqlens_q, + cu_seqlens_k, + None, + None, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. 
It might cause issue doing an empty + + # check output type + assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout_padded): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) + dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) + dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None + dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None + dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.varlen_bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.alibi_slopes, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + False, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_fp8_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None +): + return FlashAttnVarlenFP8Func.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + block_table, + torch.is_grad_enabled() + ) + +class FlashAttnQKVPackedFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + 
is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and qkv.requires_grad + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() + head_size_og = q.size(3) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + alibi_slopes, + dropout_p, + softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + return_softmax=return_softmax and dropout_p > 0, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o, + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) + head_size_og = dout.size(3) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. 
In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." + dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") + dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None + + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ctx.alibi_slopes, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, # gen_ + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + None, + None, + None, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None + + +def flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # <=0.0 means deactivate + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + return FlashAttnQKVPackedFP8Func.apply( + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + + +class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None + ): + is_grad = is_grad_enabled and qkv.requires_grad + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + + # figure out fwd parameters + if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" + q_fp8 = q + k_fp8 = k + v_fp8 = v + out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
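# Illustrative sketch (editor's note, not part of the patch): cast_to_fp8 itself is not
# shown in this diff; it also takes a layout (plus cu_seqlens/max_seqlen for "thd") and,
# judging by how the kernels index descale_q/k/v/do by batch and head, returns
# per-(batch, head) descale factors. The scalar per-tensor version below only illustrates
# the scale/descale relationship the fp8 path relies on (assumed helper name):
#
#     import torch
#
#     def cast_to_fp8_sketch(x, fp8_dtype=torch.float8_e4m3fnuz):
#         fp8_max = torch.finfo(fp8_dtype).max
#         amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
#         descale = amax / fp8_max                             # multiply on the way out of fp8
#         x_fp8 = (x.to(torch.float32) / descale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
#         return x_fp8, descale.reshape(1)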
+ q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None + + q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] + _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( + q_fp8, + k_fp8, + v_fp8, + out_fp8, + cu_seqlens, + cu_seqlens, + None, + None, + None, + alibi_slopes, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_o=descale_o + ) + if is_grad: + ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + out = out_fp8[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors + qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + # figure out bwd parameters + if is_fp8(dout_padded): # fp8 input and output + raise ValueError("fp8 input and out not supported yet for this function.") + assert (descale_do is not None), f"You need to pass descale factors for do" + dout_padded_fp8 = dout_padded + dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) + else: # cast to fp8 and return output in the fp32. (accumulator type) + assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." 
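# Illustrative sketch (editor's note, not part of the patch): the packed path accumulates
# gradients straight into one buffer of shape (total_tokens, 3, nheads, head_dim);
# dqkv[:, 0] / dqkv[:, 1] / dqkv[:, 2] are views handed to the kernel as dq / dk / dv, so
# no concatenation is needed afterwards. Shape check with assumed sizes:
#
#     import torch
#     total_tokens, nheads, head_dim = 6, 4, 64
#     q_like = torch.empty(total_tokens, nheads, head_dim)
#     qkv_shape = q_like.shape[:-2] + (3, *q_like.shape[-2:])  # (6, 3, 4, 64)
#     dqkv = torch.zeros(qkv_shape)
#     assert dqkv[:, 0].shape == q_like.shape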
+ dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) + dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None + + # dq, dk, dv are allocated by us so they should already be contiguous + dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] + flash_attn_gpu.varlen_bwd( + dout_padded_fp8, + q_fp8, + k_fp8, + v_fp8, + out_fp8, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.alibi_slopes, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + False, + ctx.causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.deterministic, + None, + rng_state, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + None, + None, + None, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_qkvpacked_fp8_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + return FlashAttnVarlenQKVPackedFP8Func.apply( + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index b37308be49..3f2d92c22d 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -1,16 +1,75 @@ import torch import triton import triton.language as tl -from .utils import _strides, get_padded_headsize - +from typing import Literal, Optional, Union +from .utils import AUTOTUNE, DEBUG, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna + +def get_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. 
+ triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + +def get_autotune_configs(): + if AUTOTUNE: + if is_cdna(): + autotune_configs, autotune_keys = get_cdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + else: + raise ValueError("Unknown Device Type") + else: + autotune_configs, autotune_keys = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + fwd_auto_tune_configs, fwd_autotune_keys= autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) + + +(fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() + +# @triton.autotune( +# configs=fwd_auto_tune_configs, +# key=fwd_autotune_keys, +# use_cuda_graph=True, +# ) @triton.jit def _fwd_kernel_splitK( Q, K, V, sm_scale, - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] K_new, V_new, Cache_seqlens, @@ -70,62 +129,91 @@ def _fwd_kernel_splitK( IS_GQA: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_ALIBI: tl.constexpr, + PADDED_HEAD: tl.constexpr, + GROUP_SIZE: tl.constexpr, ): - # Padding - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) - if PADDED_HEAD: - d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL - - start_m = tl.program_id(0) - off_zhg = tl.program_id(1) - off_z = off_zhg // (H_q * G_q) - off_h_q = (off_zhg // G_q) % H_q - off_g_q = off_zhg % G_q - splitk_idx = tl.program_id(2) + # get program ids + pid_m = tl.program_id(0) + pid_zhg = tl.program_id(1) + pid_splitk = tl.program_id(2) - # pick batch index - if USE_CACHE_BATCH_IDX: - cache_batch_idx = tl.load(Cache_batch_idx + off_z) - else: - cache_batch_idx = off_z + # compute z, h and g ids + z_id = pid_zhg // (H_q * G_q) + hq_id = (pid_zhg // G_q) % H_q + g_id = pid_zhg % G_q - # Load ALiBi slope if enabled - if USE_ALIBI: - a_offset = off_z * stride_az + off_h_q * stride_ah - alibi_slope = tl.load(Alibi_slopes + a_offset) + # is gqa + if IS_GQA: + hk_id = hq_id // GROUP_SIZE + hv_id = hk_id else: - alibi_slope = None + hk_id = hq_id + hv_id = hq_id - lo = splitk_idx * BLOCK_N_PER_SPLIT + # figure out seqlens + lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: - cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z) + cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) if NEW_KV: - kv_len = cache_seqlen_last_idx + N_CTX_NEW + N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW else: - kv_len = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: - kv_len = N_CTX_K - hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + N_CTX_K_FINAL = N_CTX_K + hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) - HEAD_RATIO: tl.constexpr = H_q // 
H_kv - if IS_GQA: - k_head_idx = off_h_q // HEAD_RATIO - v_head_idx = k_head_idx + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + z_id) + else: + cache_batch_idx = z_id + + # compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # compute ptrs + q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # compute masks + if PADDED_HEAD: + q_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[None, :] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + osk_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] else: - k_head_idx = off_h_q - v_head_idx = off_h_q + q_mask = (offs_m < N_CTX_Q)[:, None] + kT_mask = (offs_n < N_CTX_K_FINAL)[None, :] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] + osk_mask = (offs_m < N_CTX_Q)[:, None] - # calculate base offset - k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg - v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + q = (q * qk_scale).to(q.dtype) + + # load ALiBi slope if enabled + if USE_ALIBI: + a_offset = z_id * stride_az + hq_id * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + else: + alibi_slope = None # Copy new Keys and Values into Cache if NEW_KV: - knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g + knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g # Determine the starting position for new data in the cache if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + off_z) + start_idx = tl.load(Cache_seqlens + z_id) else: start_idx = N_CTX_K - N_CTX_NEW @@ -143,7 +231,7 @@ def _fwd_kernel_splitK( # Store to K tl.store( - k_base + + k_offset + tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, k_new_block, @@ -152,7 +240,7 @@ def _fwd_kernel_splitK( ) # Copy new Values - vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g + vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g for i in range(0, N_CTX_NEW, BLOCK_N): # Load from V_new v_new_block = tl.load( @@ -166,7 +254,7 @@ def _fwd_kernel_splitK( # Store to V tl.store( - v_base + + v_offset + (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, v_new_block, @@ -174,34 +262,6 @@ def _fwd_kernel_splitK( (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), ) - Q_block_ptr = tl.make_block_ptr( - base=Q + off_h_q * stride_qh + off_z * stride_qz + off_g_q * stride_qg, - shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qd), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), 
- ) - - K_block_ptr = tl.make_block_ptr( - base=k_base, - shape=(ACTUAL_BLOCK_DMODEL, hi), - strides=(stride_kd, stride_kn), - offsets=(0, lo), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - V_block_ptr = tl.make_block_ptr( - base=v_base, - shape=(hi, ACTUAL_BLOCK_DMODEL), - strides=(stride_vn, stride_vd), - offsets=(lo, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - - K_scale_shift_block_ptr = None - V_scale_shift_block_ptr = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) @@ -209,45 +269,26 @@ def _fwd_kernel_splitK( acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - # load q: it will stay in SRAM throughout - q = tl.load( # noqa: F821 - tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, )) - q = (q * qk_scale).to(q.dtype) - if PADDED_HEAD: - q = tl.where(d_mask[None, :], q, 0.0) # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): - k, v = load_k_v_group( - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N, - 1, - BLOCK_DMODEL, - ACTUAL_BLOCK_DMODEL, - Q.dtype.element_ty, - 0, - ) - if PADDED_HEAD: - k = tl.where(d_mask[:, None], k, 0.0) - v = tl.where(d_mask[None, :], v, 0.0) + kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn + V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) # noqa: F821 + qk += tl.dot(q, kT) # noqa: F821 if USE_ALIBI: - row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_idx = start_n + tl.arange(0, BLOCK_N) # Compute relative positions - relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :]) + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) relative_pos = tl.abs(relative_pos) # Compute ALiBi bias @@ -256,11 +297,11 @@ def _fwd_kernel_splitK( # Apply causal mask if IS_CAUSAL is True if IS_CAUSAL: - row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) col_idx = start_n + tl.arange(0, BLOCK_N) # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - kv_len + col_offset = N_CTX_Q - N_CTX_K_FINAL causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) # Apply the mask @@ -293,101 +334,34 @@ def _fwd_kernel_splitK( # -- scale and update acc -- acc *= alpha[:, None] acc += tl.dot(p.to(v.dtype), v) - - # update pointers - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) # write back O - O_block_ptr = tl.make_block_ptr( - base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, - shape=(N_CTX_Q, BLOCK_DMODEL), - strides=(stride_osk_m, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s + osk_ptrs = osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d tl.store( - tl.advance(O_block_ptr, (0, 0)), + osk_ptrs, acc, - boundary_check=(0, ), + mask=osk_mask, ) - # 
Write metadata for split-K reduction - Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + - tl.arange(0, BLOCK_M)) - tl.store(Metadata_ptr, m_i) - tl.store(Metadata_ptr + stride_m2, l_i) - - -@triton.jit -def load_k_v_group( - K_block_ptr, - V_block_ptr, - K_scale_shift_block_ptr, - V_scale_shift_block_ptr, - BOUNDS_CHECKS_N: tl.constexpr, - PACKED_PER_VAL: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - dtype: tl.constexpr, - group_id: tl.constexpr, -): - #Load K/V for a given block - - # Advance to the current quantization group - K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0)) - V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id)) - - # -- load k, v -- - k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ()) - v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ()) - - return k, v - - -@triton.jit -def cast_uint32_to_half2(scale_shift): - # Extract two float16 packed into one int32 - scale = scale_shift & 0xFFFF - shift = scale_shift >> 16 - scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) - shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) - return scale, shift - - -@triton.jit -def dequantize( - x_, - scale, - shift, - PACKED_PER_VAL: tl.constexpr = 8, -): - # PACKED_PER_VAL is the number of values packed into - # each element x_. For example, for int4 quantization - #and x_ of type int32, PACKED_PER_VAL is 8. - BLOCK_N: tl.constexpr = x_.shape[0] - BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] - offsets = tl.arange(0, PACKED_PER_VAL) * 4 - quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) - - quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) - # Trick - instead of converting int4 to float16 we view it as float16 - # and then multiply by 32768 * 512 == 2**24 - quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) - quant_offset = (quant_offset * 32768.0).to(tl.float16) - scale_512 = scale * 512 - - dequant = quant_offset * scale_512 + shift - return dequant + # write metadata for split-K reduction + metadata_offset = Metadata + pid_zhg * stride_mzhg + pid_splitk * stride_ms + metadata_ptr = metadata_offset + offs_m + tl.store(metadata_ptr, m_i) + tl.store(metadata_ptr + stride_m2, l_i) +# @triton.autotune( +# configs=reduce_auto_tune_configs, +# key=reduce_autotune_keys, +# use_cuda_graph=True, +# ) @triton.jit def _splitK_reduce( - Out_splitK, # [B, H, split_k, Mq, K] - Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] - Out, # [B, H, M, K] - LSE, # [B, H, M] + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, G, M, K] + LSE, # [B*H*G, M] stride_osk_zhg, stride_osk_s, stride_osk_m, @@ -403,41 +377,50 @@ def _splitK_reduce( stride_ok, stride_lse_zhg, stride_lse_m, - M_ceil: tl.constexpr, - BLOCK_SIZE: tl.constexpr, + K_BLOCK_SIZE: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, H: tl.constexpr, G: tl.constexpr, split_k: tl.constexpr, splitK_pow2: tl.constexpr, - use_mask: tl.constexpr, + MASK_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, + PADDED_HEAD: tl.constexpr, ): - off_zhg = tl.program_id(0) - off_z = off_zhg // (H * G) - off_h = (off_zhg // G) % H - off_g = off_zhg % G - off_m = tl.program_id(1) - off_k = tl.program_id(2) + # get pids + pid_zhg = tl.program_id(0) + pid_m = 
tl.program_id(1) + pid_k = tl.program_id(2) - # read chunk - spk_idx = tl.arange(0, splitK_pow2) - kidx = tl.arange(0, BLOCK_SIZE) + # compute offsets + offs_splitK = tl.arange(0, splitK_pow2) + offs_k = pid_k * K_BLOCK_SIZE + tl.arange(0, K_BLOCK_SIZE) - Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm) - o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE + - stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k) + # compute masks + if PADDED_HEAD: + o_mask = offs_k < ACTUAL_BLOCK_DMODEL + else: + o_mask = None + + # compute ptrs + metadata_offset = Metadata + pid_zhg * stride_mzhg + metadata_ptr = metadata_offset + offs_splitK * stride_ms + pid_m * stride_mm + + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_m * stride_osk_m + osk_ptr = osk_offset + offs_splitK[:, None] * stride_osk_s + offs_k[None, :] * stride_osk_k # read max values of each splitK - if use_mask: - spk_mask = spk_idx < split_k - l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) - l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) - acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0) + if MASK_SPLITK: + splitK_mask = offs_splitK < split_k + l_m = tl.load(metadata_ptr, mask=splitK_mask, other=float("-inf")) + l_sum = tl.load(metadata_ptr + stride_m2, mask=splitK_mask, other=0.0) + acc = tl.load(osk_ptr, mask=splitK_mask[:, None], other=0.0) else: - l_m = tl.load(Metadata_ptr) - l_sum = tl.load(Metadata_ptr + stride_m2) - acc = tl.load(o_ptr) + l_m = tl.load(metadata_ptr) + l_sum = tl.load(metadata_ptr + stride_m2) + acc = tl.load(osk_ptr) g_m = tl.max(l_m, axis=0) @@ -460,12 +443,15 @@ def _splitK_reduce( acc_out = tl.sum(acc, axis=0) / g_sum # Store output - Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m + - off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) - tl.store(Out_ptr, acc_out) + z_id = pid_zhg // (H * G) + h_id = (pid_zhg // G) % H + g_id = pid_zhg % G + out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og + out_ptr = out_offset + pid_m * stride_om + offs_k + tl.store(out_ptr, acc_out, mask=o_mask) # Store lse - l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m if IS_CAUSAL: lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) tl.store(l_ptrs, lse) @@ -473,6 +459,41 @@ def _splitK_reduce( tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. 
+ + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = (x_[:, None, :] >> offsets[None, :, None]) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: # Scale and shift are such that quantization linearly maps # int4 values range [0..15] to input values range min(k)..max(k) @@ -540,122 +561,204 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int: split_k = max(split_k, 1) return split_k -def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new): - # kernel config +def attention_decode_forward_triton_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + sm_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_batch_idx: Optional[torch.Tensor], +): + # triton configs BLOCK_M = 16 BLOCK_N = 64 + num_stages = 1 + num_warps_fwd = 1 + num_warps_reduce = 4 + + # kernel_configs + is_new_kv = True if k_new is not None and v_new is not None else False + use_alibi = False if alibi_slopes is None else True + use_cache_seqlens = cache_seqlens is not None SPLIT_K = None NUM_QUANT_GROUPS = 1 - # kernels expects "bsghd" - original_layout = layout + # get shapes and strides + (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) + (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + if is_new_kv: + ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) + (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) + else: + ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = (None, None, None, None), (None, None, None, None) + (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = (None, None, None, None), (None, None, None, None) + (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = get_shape_and_strides_from_layout(out, layout) + if use_alibi: + stride_az, stride_ah = alibi_slopes.stride() + else: + stride_az, stride_ah = (None, None) + + assert dim_q == dim_kc == dim_vc, f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" + + # add extra information needed by the kernels if layout == "bshd": - q=q.unsqueeze(2) - k=k.unsqueeze(2) - v=v.unsqueeze(2) - if new_kv: - k_new = k_new.unsqueeze(2) - v_new = v_new.unsqueeze(2) - layout = "bsghd" - 
elif layout == "bhsd": - q=q.permute(0, 2, 1, 3).unsqueeze(2) - k=k.permute(0, 2, 1, 3).unsqueeze(2) - v=v.permute(0, 2, 1, 3).unsqueeze(2) - if new_kv: - k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2) - v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2) - layout = "bsghd" - elif layout == "bsghd": - pass - elif layout is None: - raise ValueError("Layout not given") - assert layout == "bsghd" - - # get dims - batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape - _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape - _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape - - assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}" + (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm + (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n + (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n + if is_new_kv: + (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n + (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n + else: + (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None + (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None + (n_group_o, heads_per_group_o), stride_og = (1, nheads_o), stride_om + else: + raise ValueError(f"{layout} layout is not supported") # get padded size - dim_padded = get_padded_headsize(dim_k) + dim_padded = get_padded_headsize(dim_kc) + is_padded_head = dim_padded != dim_kc # Handle MQA/GQA case - if heads_per_group_q > heads_per_group_k: + group_size = nheads_q // nheads_kc + if group_size > 1: is_gqa = True - elif heads_per_group_q < heads_per_group_k: - raise ValueError("heads_per_group_q < heads_per_group_k") else: is_gqa = False - assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}" - if SPLIT_K is not None: split_k = SPLIT_K else: # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_k) # NOTE: should the split think about seqlens? + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? 
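The grouping above follows the usual GQA convention: group_size = nheads_q // nheads_kc consecutive query heads share one KV head (the kernel does hk_id = hq_id // GROUP_SIZE), and the KV cache is covered by split_k roughly equal chunks that are reduced afterwards (split_size is computed just below). A minimal host-side sketch of both computations; gqa_kv_head and split_ranges are illustrative names, not part of this patch.

# Illustrative sketch (not part of the patch): GQA head mapping and split-K chunking.
def gqa_kv_head(hq_id, nheads_q, nheads_kc):
    group_size = nheads_q // nheads_kc   # GROUP_SIZE passed to the kernel
    return hq_id // group_size           # mirrors hk_id = hq_id // GROUP_SIZE in _fwd_kernel_splitK

def split_ranges(seqlen_kc, split_k):
    split_size = (seqlen_kc + split_k - 1) // split_k   # BLOCK_N_PER_SPLIT
    # mirrors lo = pid_splitk * BLOCK_N_PER_SPLIT, hi = min((pid_splitk + 1) * BLOCK_N_PER_SPLIT, seqlen)
    return [(i * split_size, min((i + 1) * split_size, seqlen_kc)) for i in range(split_k)]

# e.g. 8 query heads sharing 2 KV heads, and a 1000-token cache covered by 4 splits:
assert [gqa_kv_head(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
assert split_ranges(1000, 4) == [(0, 250), (250, 500), (500, 750), (750, 1000)]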
+ split_size = (seqlen_kc + split_k - 1) // split_k + # setup grid seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M - out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device) + grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch_size * n_group_q * heads_per_group_q, split_k) + + # create intermediate tensors + out_splitk = torch.empty([batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], dtype=torch.float32, device=q.device) metadata = torch.empty([batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device) - lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), device=q.device, dtype=torch.float32) - grid = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * n_group_q * heads_per_group_q, split_k) - - num_warps = 1 - split_size = (seqlen_k + split_k - 1) // split_k - use_cache_seqlens = cache_seqlens is not None + lse = torch.empty((batch_size * n_group_q * heads_per_group_q, seqlen_q), dtype=torch.float32, device=q.device) + + # get intermediate tensor strides + stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() + stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() + stride_lse_zhg, stride_lse_m = lse.stride() + + if False: + print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) + print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) + print("dim_padded:", dim_padded) + print("stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd)) + print("stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d)) + print("stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d)) + if is_new_kv: + print("stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d)) + print("stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d)) + print("stride_oz, stride_om, stride_og, stride_oh, stride_od", (stride_oz, stride_om, stride_og, stride_oh, stride_od)) + print("stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d)) + print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) + print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) # TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, - K=k, - V=v, + K=k_cache, + V=v_cache, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new = k_new, - V_new = v_new, + K_new=k_new, + V_new=v_new, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, Alibi_slopes=alibi_slopes, - **_strides(q, "qz", "qm", "qg", "qh", "qd"), - **_strides(k, "kz", "kn", "kg", "kh", "kd"), - **_strides(v, "vz", "vn", "vg", "vh", "vd"), - **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), - **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"), - **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"), - **_strides(alibi_slopes, "az", "ah"), + # q strides + stride_qz=stride_qz, + stride_qm=stride_qm, + stride_qg=stride_qg, + 
stride_qh=stride_qh, + stride_qd=stride_qd, + # k strides + stride_kz=stride_kc_z, + stride_kn=stride_kc_n, + stride_kg=stride_kc_g, + stride_kh=stride_kc_h, + stride_kd=stride_kc_d, + # v strides + stride_vz=stride_vc_z, + stride_vn=stride_vc_n, + stride_vg=stride_vc_g, + stride_vh=stride_vc_h, + stride_vd=stride_vc_d, + # out_splitk strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + stride_osk_d=stride_osk_d, + # metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # k_new strides + stride_kn_z=stride_kn_z, + stride_kn_n=stride_kn_n, + stride_kn_g=stride_kn_g, + stride_kn_h=stride_kn_h, + stride_kn_d=stride_kn_d, + # v_new strides + stride_vn_z=stride_vn_z, + stride_vn_n=stride_vn_n, + stride_vn_g=stride_vn_g, + stride_vn_h=stride_vn_h, + stride_vn_d=stride_vn_d, + # alibi strides + stride_az=stride_az, + stride_ah=stride_ah, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, - N_CTX_K=seqlen_k, - N_CTX_NEW=k_new.shape[1] if new_kv else None, + N_CTX_K=seqlen_kc, + N_CTX_NEW=seqlen_kn, BLOCK_N_PER_SPLIT=split_size, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, - ACTUAL_BLOCK_DMODEL=dim_k, + ACTUAL_BLOCK_DMODEL=dim_kc, BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=new_kv, + NEW_KV=is_new_kv, IS_GQA=is_gqa, IS_CAUSAL=causal, - USE_ALIBI=False if alibi_slopes is None else True, - num_warps=num_warps, - num_stages=1, + USE_ALIBI=use_alibi, + PADDED_HEAD=is_padded_head, + GROUP_SIZE=group_size, + num_warps=num_warps_fwd, + num_stages=num_stages, ) - out = torch.empty((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype) + if DEBUG: + print("Out_splitK:", out_splitk, out_splitk.shape) + print("metadata:", metadata, metadata.shape) + print("lse:", lse, lse.shape) + print("Out:", out, out.shape) # Merge together splitK_pow2 = triton.next_power_of_2(split_k) - use_mask = splitK_pow2 > split_k + mask_split_k = splitK_pow2 > split_k if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: k_block_num = 1 else: @@ -664,40 +767,48 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes k_block_size = dim_padded // k_block_num grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + if DEBUG: + print("splitK_pow2:", splitK_pow2) + print("k_block_num:", k_block_num) + print("k_block_size:", k_block_size) + print("grid:", grid) + _splitK_reduce[grid]( out_splitk, metadata, out, lse, - **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), - **_strides(metadata, "mzhg", "m2", "ms", "mm"), - **_strides(out, "oz", "om", "og", "oh", "ok"), - **_strides(lse, "lse_zhg", "lse_m"), - M_ceil=seqlen_q_ceil, - BLOCK_SIZE=k_block_size, + # Split-K output strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + stride_osk_k=stride_osk_d, + # Metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # Output tensor strides + stride_oz=stride_oz, + stride_oh=stride_oh, + stride_og=stride_og, + stride_om=stride_om, + stride_ok=stride_od, + # LSE strides + stride_lse_zhg=stride_lse_zhg, + stride_lse_m=stride_lse_m, + K_BLOCK_SIZE=k_block_size, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_kc, G=n_group_q, H=heads_per_group_q, # 
TODO: Tune num_warps split_k=split_k, splitK_pow2=splitK_pow2, - use_mask=use_mask, + MASK_SPLITK=mask_split_k, IS_CAUSAL=causal, - num_warps=4) - - lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q]) - if q.ndim == 4: - # BMGHK -> BMHK - assert n_group_q == 1 - out = out[:, :, 0] - lse = lse[:, 0] - if seqlen_k == 0: - out.zero_() - out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous() - - # output is batch_size, heads_per_group_q * group_q, seqlen_q, dim_q - if original_layout == "bshd": - # out=out.transpose(1, 2).contiguous() # this screws up heads and data. - # the data is laid out properly. Just need to reshape dims - out = out.reshape(batch_size, seqlen_q, -1, dim_padded) - - return out.narrow(-1, 0, dim_k), lse + PADDED_HEAD=is_padded_head, + num_warps=num_warps_reduce) + + return lse diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 2a59dc4e5d..dec5673e3e 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -1,32 +1,12 @@ import torch import triton import triton.language as tl -from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE - -@triton.jit -def cdiv_fn(x, y): - return (x + y - 1) // y - -@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) - rng_keep = rng_output > dropout_p - return rng_keep +from typing import Literal, Optional, Union +from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_shapes_from_layout, get_strides_from_layout, is_cdna, is_fp8, is_rdna, create_dropout_mask +# NOTE: triton fails to import tl.constexprs so create them here for the file +tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) +tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) # Convenience function to load with optional boundary checks. # "First" is the major dim, "second" is the minor dim. @@ -46,49 +26,16 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): tensor = tl.load(ptrs) return tensor - -@triton.jit -def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): - # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix - # for casual mask we want something like this where (1 is kept and 0 is masked) - # seqlen_q = 2 and seqlen_k = 5 - # 1 1 1 1 0 - # 1 1 1 1 1 - # seqlen_q = 5 and seqlen_k = 2 - # 0 0 - # 0 0 - # 0 0 - # 1 0 - # 1 1 - # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal - # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False - # 1. 
offs_m[:,None] = [[0], - # [1], - # 2. offs_m[:,None] + seqlen_k = [[5], - # [6], - # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], - # [4], - # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], - # [4], [ 4, 3, 2, 1, 0]] - # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], - # [ -4, -3, -2, -1, 0]], - relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] - alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) - if transpose: - return alibi_block.T - else: - return alibi_block - - @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, exp_scores_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr, - RETURN_SCORES: tl.constexpr): + ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 @@ -105,7 +52,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. 
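Both _fwd_kernel_splitK and _attn_fwd_inner rely on the same numerics: scores are pre-multiplied by log2(e) (the 1.44269504 factor), exp2 replaces exp inside the loop, and the running accumulator is rescaled by alpha = 2^(m_old - m_new) whenever a new row maximum appears; the log-sum-exp is converted back to base e at the end, as in _splitK_reduce. A rough single-row PyTorch sketch of that online softmax, for intuition only; online_softmax_exp2 is a made-up name, not kernel code.

# Illustrative sketch (not part of the patch): exp2-based online softmax over score blocks.
import torch

def online_softmax_exp2(score_blocks, v_blocks, sm_scale):
    LOG2E = 1.44269504
    m = torch.tensor(float("-inf"))       # running max in the base-2 domain
    l = torch.tensor(0.0)                 # running sum of exp2 terms
    acc = torch.zeros(v_blocks[0].shape[-1])
    for s, v in zip(score_blocks, v_blocks):
        qk = s * sm_scale * LOG2E         # qk_scale = sm_scale * log2(e), as in the kernels
        m_new = torch.maximum(m, qk.max())
        alpha = torch.exp2(m - m_new)     # rescale previously accumulated results
        p = torch.exp2(qk - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v
        m = m_new
    out = acc / l
    lse = (m + torch.log2(l)) / LOG2E     # back to natural log, as in _splitK_reduce
    return out, lse

# quick check against a direct softmax over the concatenated blocks
s = [torch.randn(4), torch.randn(4)]
v = [torch.randn(4, 8), torch.randn(4, 8)]
out, lse = online_softmax_exp2(s, v, sm_scale=0.5)
ref = torch.softmax(torch.cat(s) * 0.5, dim=0) @ torch.cat(v)
assert torch.allclose(out, ref, atol=1e-5)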
@@ -120,13 +67,18 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - + + # compute masks + q_mask = (OFFS_M[:, None] < actual_seqlen_q) + k_mask = ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) + p_mask = q_mask & k_mask + # -- compute qk ---- - qk += tl.dot(q, k) + if IS_FP8 : + qk += (tl.dot(q, k) * descale_q * descale_k) + else: + qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE - if RETURN_SCORES: - score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(score_ptrs, qk_scaled, mask=score_mask) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -137,8 +89,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) qk_scaled += bias - if alibi_slope is not None: - # Compute the global position of each token within the sequence + if USE_ALIBI: + # compute the global position of each token within the sequence global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) global_n_positions = start_n + tl.arange(0, BLOCK_N) alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, @@ -149,10 +101,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # scale and subtract max q_shifted = qk_scaled - m_ij[:, None] - if RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - scores_scaled_shifted_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(scores_scaled_shifted_ptrs, q_shifted, mask=scores_scaled_shifted_mask) # Compute scaled QK and softmax probabilities if USE_EXP2: @@ -163,17 +111,23 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) - if RETURN_SCORES: - # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(exp_scores_ptrs, tl.where(keep, p, -p), mask=exp_score_mask) - p = tl.where(keep, p, 0.0) + if tl_DROPOUT_USE_PYTORCH: + dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask) + else: + rng_output = tl.rand(philox_seed, philox_ptrs) # TODO: use tl.randint for better performance + dropout_mask = rng_output > dropout_p + if tl_DROPOUT_DUMP: + tl.store(dropout_mask_ptrs, dropout_mask, mask=p_mask) + + # return scores with negative values for dropped vals + sd_mask = tl.where(dropout_mask, p, -p) + tl.store(sd_mask_ptrs, sd_mask, mask=p_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) elif RETURN_SCORES: # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that - exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k) - tl.store(exp_scores_ptrs, p, mask=exp_score_mask) + tl.store(sd_mask_ptrs, p, mask=p_mask) # -- update output accumulator -- # alpha is an adjustment factor for acc and li as we loop and find new maxes @@ -190,15 +144,23 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(v.type.element_ty), v) + + if IS_FP8: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + else: + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: bias_ptrs += BLOCK_N * stride_bn if RETURN_SCORES: - score_ptrs += BLOCK_N - scores_scaled_shifted_ptrs += BLOCK_N - exp_scores_ptrs += BLOCK_N + sd_mask_ptrs += BLOCK_N * stride_sn + + if ENABLE_DROPOUT: + dropout_mask_ptrs += BLOCK_N * stride_sn + philox_ptrs += BLOCK_N * stride_sn return acc, l_i, m_i @@ -219,7 +181,7 @@ def get_cdna_autotune_configs(): # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] def get_rdna_autotune_configs(): @@ -239,7 +201,7 @@ def get_rdna_autotune_configs(): # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] def get_autotune_configs(): @@ -263,7 +225,7 @@ def get_autotune_configs(): "MAX_SEQLENS_Q", "MAX_SEQLENS_K", "ACTUAL_BLOCK_DMODEL", - "VARLEN", + "IS_VARLEN", "HQ", "HK", ] @@ -277,34 +239,46 @@ def get_autotune_configs(): use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, +def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, + Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, scores, scores_scaled_shifted, exp_scores, alibi_slopes, HQ: tl.constexpr, + dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_INFERENCE: tl.constexpr, IS_CAUSAL: 
tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr): + ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + # set params + ACCUMULATOR_TYPE = tl.float32 + + # compute offsets start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) - if VARLEN: + + # handle seqlen + if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - # print("cu_seqlens_q_start:", cu_seqlens_q_start) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. + + # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + elif IS_INFERENCE: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = tl.load(Cache_seqlens + off_z) else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 @@ -317,14 +291,14 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # inf written to LSE. We don't need to do any GEMMs in this case. # This block of code determines what N is, and if this WG is operating # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + n_blocks = tl.cdiv(seqlen_k, BLOCK_N) if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + n_blocks_seqlen = tl.cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) @@ -341,9 +315,9 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # statically known. 
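In the causal case the kernel only visits K/V blocks that can contain unmasked entries; because the mask is bottom-right aligned, the last useful K column for query block start_m is (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q. A small sketch of that block-count computation; causal_n_blocks is a hypothetical helper mirroring the code above.

# Illustrative sketch (not part of the patch): causal block counting.
def cdiv(x, y):
    return (x + y - 1) // y

def causal_n_blocks(start_m, seqlen_q, seqlen_k, BLOCK_M, BLOCK_N):
    n_blocks = cdiv(seqlen_k, BLOCK_N)
    # bottom-right aligned causal mask: this query block cannot see past this column
    n_blocks_seqlen = cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
    return min(n_blocks, n_blocks_seqlen)

# seqlen_q = seqlen_k = 512 with 128x128 blocks: the first query block visits one
# K/V block, the last visits all four.
assert [causal_n_blocks(m, 512, 512, 128, 128) for m in range(4)] == [1, 2, 3, 4]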
l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m - - l = tl.full([BLOCK_M], value=0.0, dtype=tl.float32) - + + l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE) + # mask_m_offsets = start_m + tl.arange(0, BLOCK_M) # lse_mask = mask_m_offsets < causal_start_idx # softmax_lse = tl.where(lse_mask, 0.0, softmax_lse) @@ -391,34 +365,37 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ alibi_slope = None if RETURN_SCORES: - scores_offset = scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - score_ptrs = scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - - scores_scaled_shifted_offset = scores_scaled_shifted + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - scores_scaled_shifted_ptrs = scores_scaled_shifted_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn - - exp_scores_offset = exp_scores + off_z * stride_sz + off_h_q * stride_sh + cu_seqlens_q_start * stride_sm - exp_scores_ptrs = exp_scores_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + sd_mask_offset = sd_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + sd_mask_ptrs = sd_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: - score_ptrs = None - scores_scaled_shifted_ptrs = None - exp_scores_ptrs = None + sd_mask_ptrs = None if ENABLE_DROPOUT: - off_hz = off_z * HQ + off_h_q - batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + dropout_mask_offset = dropout_mask + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + dropout_mask_ptrs = dropout_mask_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn + batch_philox_offset = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm + philox_ptrs = batch_philox_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn else: - batch_philox_offset = 0 + dropout_mask_ptrs = None + philox_ptrs = 0 # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) + l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + # Load scale factors if IS_FP8. + if IS_FP8: + descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) + descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) + descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) @@ -439,16 +416,17 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ # value because there is no masking. Similarly we do not need padding. 
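The descale_q/descale_k/descale_v factors loaded above undo a per-tensor FP8 quantization done on the host: each tensor is scaled so its absolute maximum maps onto the FP8 range, and a dot product of two quantized operands is multiplied by the product of their descale factors. A rough sketch of what compute_fp8_scaling_factors (imported from .utils above) is assumed to do; the real helper may differ in details.

# Illustrative sketch (not part of the patch): per-tensor FP8 scale/descale factors.
import torch

def fp8_scaling_factors(x: torch.Tensor, fp8_max: float):
    scale = fp8_max / x.abs().amax().clamp(min=1e-12)  # map |x|max onto the FP8 range
    return scale, 1.0 / scale                          # (scale, descale)

# after quantizing q and k with their scales, the dot product is recovered as
#   qk ~= dot(q_fp8, k_fp8) * descale_q * descale_k
# which is how qk is descaled in _attn_fwd_inner above when IS_FP8 is set.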
if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - exp_scores_ptrs, + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + block_min, block_max, 0, 0, 0, alibi_slope, + descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, # IS_CAUSAL, .... False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, False, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) block_min = block_max block_max = n_blocks * BLOCK_N @@ -464,23 +442,25 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ if USE_BIAS: bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_SCORES: - score_ptrs += n_full_blocks * BLOCK_N - scores_scaled_shifted_ptrs += n_full_blocks * BLOCK_N - exp_scores_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - exp_scores_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, score_ptrs, scores_scaled_shifted_ptrs, + sd_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + if ENABLE_DROPOUT: + dropout_mask_ptrs += n_full_blocks * BLOCK_N * stride_sn + philox_ptrs += n_full_blocks * BLOCK_N * stride_sn + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, + sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES) + ACTUAL_BLOCK_DMODEL, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE) # epilogue # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. l_recip = 1 / l_i[:, None] acc = acc * l_recip if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, # then we have one block with a row of all NaNs which come from computing # softmax over a row of all -infs (-inf - inf = NaN). 
We check for that here @@ -488,7 +468,6 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) @@ -496,7 +475,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - + # write back LSE(Log Sum Exponents), the log of the normalization constant l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m l_ptrs = l_offset + offs_m * stride_lse_m @@ -510,7 +489,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ softmax_lse *= LN2 else: softmax_lse = m_i + tl.math.log(l_i) - + if IS_CAUSAL: # zero out nans caused by -infs when doing causal lse_mask = (start_m_idx + tl.arange(0, BLOCK_M)) < causal_start_idx @@ -534,55 +513,83 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_ o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) if PADDED_HEAD: o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + if FP8_OUTPUT: + scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) + tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) + tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) + else: + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def attention_prefill_forward_triton_impl( - q, - k, - v, - o, - sm_scale, - alibi_slopes, - causal, - bias, - dropout_p, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - return_scores, - use_exp2): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + bias: Optional[torch.Tensor], + layout: Literal["bshd", "bhsd", "thd"], + # varlen + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlens_q: int, + max_seqlens_k: int, + # inference + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + cache_batch_idx: Optional[torch.Tensor], + # dropout + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + # misc + return_softmax: bool, + use_exp2: bool, + # fp8 + descale_q: Optional[torch.Tensor], + descale_k: Optional[torch.Tensor], + descale_v: Optional[torch.Tensor], + descale_o: Optional[torch.Tensor], +): + IS_FP8 = is_fp8(q) + if IS_FP8: + FP8_MAX: tl.constexpr = torch.finfo(q.dtype).max + + assert is_fp8(q) and is_fp8(k) and is_fp8(v), f"Non fp8 type found: q.dtype={q.dtype}, k.dtype={k.dtype}, v.dtype={v.dtype}. All tensors must be fp8." + + if is_fp8(o): + FP8_OUTPUT = True + assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor for the output." 
+ else: + FP8_OUTPUT = False - if DEBUG: - print() - print("attention_prefill_forward_triton_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("sm_scale:", sm_scale) - print("alibi_slopes:", alibi_slopes) - print("causal:", causal) - print("bias:", bias) - print("dropout_p:", dropout_p) - print("layout:", layout) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlens_q:", max_seqlens_q) - print("max_seqlens_k:", max_seqlens_k) - print("return_scores:", return_scores) - print("use_exp2:", use_exp2) - - # check if varlen + # Get strides for the kernel + stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None + stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None + stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None + stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + else: + FP8_MAX = None + FP8_OUTPUT = False + descale_q = descale_k = descale_v = descale_o = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None + + # check flags is_varlen = layout == "thd" + use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + is_inference = False if cache_seqlens is None else True + if is_inference: + assert layout == "bshd", f"{layout} layout is not supported with inference. Use bshd layout" + if DEBUG: + print(f"is_inference:", is_inference) # NOTE: a large bias tensor leads to overflow during pointer arithmetic if (bias is not None): assert (bias.numel() < 2**31) - batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) + batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shapes_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k) q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout) # Get closest power of 2 over or equal to 32. @@ -593,60 +600,49 @@ def attention_prefill_forward_triton_impl( grid = lambda META: (triton.cdiv(max_seqlens_q, META['BLOCK_M']), nheads_q, batch) - if return_scores: - scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - scores_scaled_shifted = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, - dtype=torch.float32) - scores_strides = (scores.stride(0), scores.stride(1), scores.stride(2), scores.stride(3)) - else: - scores = None - scores_scaled_shifted = None - scores_strides = (0, 0 , 0 , 0) - - # exp_scores is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing - # only. This return holds no useful output aside from debugging. - if return_scores: - exp_scores = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + # only. This return holds no useful output aside from debugging. 
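As the comment above notes, sd_mask exists only to validate dropout against a PyTorch reference: the kernel writes the (block-local) exponentiated scores with the sign flipped for dropped positions, so a test can recover both the dropout pattern and the pre-dropout values from one buffer. A sketch of such a consumer; decode_sd_mask is hypothetical and not part of this patch.

# Illustrative sketch (not part of the patch): recovering the dropout pattern from sd_mask.
import torch

def decode_sd_mask(sd_mask: torch.Tensor):
    keep_mask = sd_mask > 0   # position survived dropout (sd_mask holds +p for kept, -p for dropped)
    p_abs = sd_mask.abs()     # exponentiated scores before dropout scaling
    return keep_mask, p_abs

# a reference implementation can then apply exactly the same pattern:
#   p_ref = torch.where(keep_mask, p_ref, torch.zeros_like(p_ref)) / (1 - dropout_p)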
+ use_dropout = (dropout_p > 0.0) + if use_dropout or return_softmax: + sd_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, + dtype=torch.float32) + if DROPOUT_USE_PYTORCH: + dropout_mask = create_dropout_mask(dropout_p, (batch, nheads_q, max_seqlens_q, max_seqlens_k), seed = philox_seed) + else: + dropout_mask = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), device=q.device, dtype=torch.float32) + scores_strides = (sd_mask.stride(0), sd_mask.stride(1), sd_mask.stride(2), sd_mask.stride(3)) else: - exp_scores = None + sd_mask = None + dropout_mask = None + scores_strides = (0, 0, 0, 0) # stores LSE the log of the normalization constant / sum of expoential score(unnormalzied probablities) if is_varlen: - softmax_lse = torch.empty((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros((q.shape[0], nheads_q), device=q.device, dtype=torch.float32) stride_lse_m, stride_lse_h = softmax_lse.stride() stride_lse_z = 0 else: - softmax_lse = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + softmax_lse = torch.zeros((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() - # Seed the RNG so we get reproducible results for testing. - philox_seed = 0x1BF52 - philox_offset = 0x1D4B42 - if bias is not None: bias_strides = (bias.stride(0), bias.stride(1),bias.stride(2), bias.stride(3)) else: bias_strides = (0, 0, 0, 0) - if alibi_slopes is not None: - alibi_strides = (alibi_slopes.stride(0), alibi_slopes.stride(1)) - else: - alibi_strides = (0, 0) - - - attn_fwd[grid](q, k, v, bias, sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, - *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, - dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, scores=scores, - scores_scaled_shifted=scores_scaled_shifted, exp_scores=exp_scores, alibi_slopes=alibi_slopes, + attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx, + descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen, + MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference, BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, - USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p - > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_scores) + USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p + > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) - return o, softmax_lse, exp_scores, grid, head_size, philox_seed, philox_offset, scores, scores_scaled_shifted + return softmax_lse, sd_mask if return_softmax else None diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py index 1cc51d17e7..baefb2410c 100644 
--- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -1,9 +1,12 @@ import torch import math -from .utils import DEBUG +from typing import Literal, Optional +from .utils import DEBUG, compute_alibi_tensor_ref -def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): - if DEBUG: +DEBUG_CORE = False + +def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): + if DEBUG_CORE: print() print("attention_forward_core_ref_impl") print("q:", q, q.shape) @@ -11,18 +14,42 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): print("v:", v, v.shape) print("sm_scale:", sm_scale) print("causal:", causal) + print("dropout_p:", dropout_p) + print("philox_seed:", philox_seed) + print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) + + # cast to float32 + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) # Compute attention scores - attention_scores = torch.matmul(q.to(torch.float32), k.transpose(-2, -1).to(torch.float32)) - if DEBUG: + attention_scores = torch.matmul(q, k.transpose(-2, -1)) + if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) # Scale scores attention_scaled_scores = sm_scale * attention_scores - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores:", attention_scaled_scores, attention_scaled_scores.shape) + # Apply ALiBi if slopes are provided + if alibi_slopes is not None: + L_q, L_k = q.shape[1], k.shape[1] + if DEBUG_CORE: + print("alibi_slopes:", alibi_slopes, alibi_slopes.shape) + alibi_bias = compute_alibi_tensor_ref(alibi_slopes, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias:", alibi_bias, alibi_bias.shape) + alibi_bias = alibi_bias.reshape(-1, L_q, L_k) + if DEBUG_CORE: + print("alibi_bias_flat:", alibi_bias, alibi_bias.shape) + attention_scaled_scores = attention_scaled_scores + alibi_bias + if DEBUG_CORE: + print("attention_scaled_scores after alibi:", attention_scaled_scores, attention_scaled_scores.shape) + + # Apply causal mask if necessary if causal: L_q, L_k = q.shape[1], k.shape[1] @@ -30,19 +57,18 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): col_idx = torch.arange(L_k, device=q.device).unsqueeze(0) col_offset = L_q-L_k causal_mask = row_idx >= (col_offset + col_idx) - if DEBUG: + if DEBUG_CORE: print("causal_mask:", causal_mask) # set -inf to places the causal mask is false attention_scaled_scores = attention_scaled_scores.masked_fill( torch.logical_not(causal_mask.unsqueeze(0)), float('-inf') ) - if DEBUG: + if DEBUG_CORE: print("attention_scaled_scores after causal:", attention_scaled_scores, attention_scaled_scores.shape) - # Compute max for numerical stability max_scores = torch.max(attention_scaled_scores, dim=-1, keepdim=True)[0] - if DEBUG: + if DEBUG_CORE: print("max_scores:", max_scores, max_scores.shape) if causal: # Replace -inf in max_scores with zeros to avoid NaN in subtraction @@ -54,7 +80,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): # Shift scores attention_shifted_scaled_scores = attention_scaled_scores - max_scores - if DEBUG: + if DEBUG_CORE: print("attention_shifted_scaled_scores:", attention_shifted_scaled_scores, attention_shifted_scaled_scores.shape) # Exponentiate @@ -64,12 +90,12 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): else: exp_scores = torch.exp(attention_shifted_scaled_scores) - if DEBUG: + if DEBUG_CORE: 
print("exp_scores:", exp_scores, exp_scores.shape) # Sum of exponentials sum_exp_scores = torch.sum(exp_scores, dim=-1, keepdim=True) - if DEBUG: + if DEBUG_CORE: print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) if causal: # if sum of exp scores is 0.0 it means scores where -inf, we cannot compute softmax and softmax_lse. Setting to 1 deals with -inf case cleanly @@ -78,15 +104,32 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): torch.ones_like(sum_exp_scores), sum_exp_scores ) - if DEBUG: + if DEBUG_CORE: print("sum_exp_scores:", sum_exp_scores, sum_exp_scores.shape) # Compute softmax probabilities - softmax = exp_scores / sum_exp_scores - - if DEBUG: - print("softmax:", softmax, softmax.shape) - + p = exp_scores / sum_exp_scores + + if DEBUG_CORE: + print("softmax:", p, p.shape) + + # apply dropout if specified + if dropout_p > 0.0: + rand_vals = torch.rand(p.shape, generator=torch.Generator(device=p.device).manual_seed(philox_seed), device=p.device, dtype=p.dtype) + dropout_mask, dropout_scale = rand_vals > dropout_p, (1.0 / (1 - dropout_p)) + if DEBUG_CORE: + print("dropout_scale:", dropout_scale) + print("dropout_mask:", dropout_mask) + # Apply dropout mask and scale + # Set -1 for dropped positions and 1 for kept positions in exp_scores + sd_mask = torch.where(dropout_mask, exp_scores, -exp_scores) + p = torch.where(dropout_mask, p , torch.zeros_like(p)) * dropout_scale + if DEBUG_CORE: + print("softmax after dropout:", p) + print("sd_mask:", sd_mask) + else: + sd_mask = exp_scores + # Compute log-sum-exp if use_exp2: LN2 = math.log(2) @@ -99,17 +142,22 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2): softmax_lse = max_scores + torch.log(sum_exp_scores) softmax_lse = softmax_lse.squeeze(-1) - if DEBUG: + if DEBUG_CORE: print("softmax_lse:", softmax_lse, softmax_lse.shape) # Compute output - o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16) - if DEBUG: + o = torch.matmul(p, v) + if DEBUG_CORE: print("o:", o, o.shape) - return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + # cast back to original dtype + o = o.to(torch.float16) + # softmax_lse = softmax_lse.to(torch.float16) # NOTE: if you cast lse to fp16 it cause accuracy issues. 
keep fp32 + sd_mask = sd_mask.to(torch.float16) + + return o, softmax_lse, sd_mask -def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, use_exp2): +def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2): """Compute reference output and softmax_lse using PyTorch's built-in function""" # Ensure the layout is 'bhsd' @@ -120,34 +168,54 @@ def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout elif layout != "bhsd": raise ValueError(f"Unknown layout {layout}") - # Prepare tensors in [batch_size * num_heads, seq_len, head_dim] format - batch_size, num_heads, seq_len_q, head_dim = q.shape - seq_len_k = k.shape[2] - - # Merge batch and heads dimensions - q = q.reshape(batch_size * num_heads, seq_len_q, head_dim) - k = k.reshape(batch_size * num_heads, seq_len_k, head_dim) - v = v.reshape(batch_size * num_heads, seq_len_k, head_dim) + # Prepare tensors + batch_size, nheads_q, seq_len_q, head_dim = q.shape + batch_size, nheads_k, seq_len_k, head_dim = k.shape + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") + + if group_size != 1: + # MQA or GQA case + # Reshape q to [batch_size, nheads_k, group_size, seq_len_q, head_dim] + q = q.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + # Expand k and v to match group_size + k = k.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + v = v.unsqueeze(2).expand(-1, -1, group_size, -1, -1) + # Flatten the first three dimensions for computation + q = q.reshape(batch_size * nheads_k * group_size, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k * group_size, seq_len_k, head_dim) + else: + q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim) + k = k.reshape(batch_size * nheads_k, seq_len_k, head_dim) + v = v.reshape(batch_size * nheads_k, seq_len_k, head_dim) # Call the core attention function - o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores = attention_forward_core_ref_impl( - q, k, v, sm_scale, causal, use_exp2 + o, softmax_lse, sd_mask = attention_forward_core_ref_impl( + q, k, v, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2 ) - # Reshape outputs back to [batch_size, num_heads, seq_len, head_dim] - o = o.reshape(batch_size, num_heads, seq_len_q, head_dim) - softmax_lse = softmax_lse.reshape(batch_size, num_heads, seq_len_q) - exp_scores = exp_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - softmax = softmax.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_shifted_scaled_scores = attention_shifted_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_scaled_scores = attention_scaled_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) - attention_scores = attention_scores.reshape(batch_size, num_heads, seq_len_q, seq_len_k) + if group_size != 1: + # Reshape outputs back to original dimensions + o = o.reshape(batch_size, nheads_k, group_size, seq_len_q, head_dim) + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_k, group_size, seq_len_q) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_k, group_size, seq_len_q, seq_len_k) + sd_mask = 
sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) + else: + # Standard case + o = o.reshape(batch_size, nheads_q, seq_len_q, head_dim) + softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q) + sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k) # Restore original layout if necessary if layout == "bshd": o = o.transpose(1, 2) - return o, softmax_lse, exp_scores, softmax, attention_shifted_scaled_scores, attention_scaled_scores, attention_scores + return o, softmax_lse, sd_mask + def attention_varlen_forward_pytorch_ref_impl( q, @@ -160,6 +228,10 @@ def attention_varlen_forward_pytorch_ref_impl( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2 ): # Ensure the layout is 'thd' @@ -167,15 +239,21 @@ def attention_varlen_forward_pytorch_ref_impl( raise ValueError(f"Unsupported layout {layout}. Expected 'thd'.") batch_size = cu_seqlens_q.shape[0] - 1 - num_heads = q.shape[1] + nheads_q, nheads_k = q.shape[1], k.shape[1] head_dim = q.shape[2] # Pre-allocate outputs total_L_q = q.shape[0] total_L_k = k.shape[0] - o = torch.empty((total_L_q, num_heads, head_dim), dtype=q.dtype, device=q.device) - softmax_lse = torch.empty((total_L_q, num_heads), dtype=torch.float32, device=q.device) + o = torch.zeros((total_L_q, nheads_q, head_dim), dtype=q.dtype, device=q.device) + softmax_lse = torch.zeros((total_L_q, nheads_q), dtype=torch.float32, device=q.device) + sd_mask = torch.zeros((batch_size, nheads_q, max_seqlen_q, max_seqlen_k), dtype=torch.float32, device=q.device) + + # Compute group_size for MQA/GQA handling + group_size = nheads_q // nheads_k + if nheads_q % nheads_k != 0: + raise ValueError("nheads_q must be divisible by nheads_k") for i in range(batch_size): # Get the start and end indices for the current sequence @@ -184,136 +262,126 @@ def attention_varlen_forward_pytorch_ref_impl( start_k = cu_seqlens_k[i].item() end_k = cu_seqlens_k[i + 1].item() + seqlen_q = end_q - start_q + seqlen_k = end_k - start_k + + if DEBUG: + print(f"Batch {i} with seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, Hq= {nheads_q}, Hk = {nheads_k}") + # Extract q_i, k_i, v_i - q_i = q[start_q:end_q, :, :] # [L_q_i, num_heads, head_dim] - k_i = k[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] - v_i = v[start_k:end_k, :, :] # [L_k_i, num_heads, head_dim] + q_i = q[start_q:end_q, :, :] # [L_q_i, nheads_q, head_dim] + k_i = k[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] + v_i = v[start_k:end_k, :, :] # [L_k_i, nheads_k, head_dim] - # Permute to [num_heads, L_q_i, head_dim] + # Permute to [nheads, L_q_i, head_dim] q_i = q_i.permute(1, 0, 2) k_i = k_i.permute(1, 0, 2) v_i = v_i.permute(1, 0, 2) + # Handle MQA/GQA by adjusting shapes based on group_size + if group_size != 1: + # Reshape q_i to [nheads_k, group_size, L_q_i, head_dim] + q_i = q_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Expand k_i and v_i to match group_size + k_i = k_i.unsqueeze(1).expand(-1, group_size, -1, -1) + v_i = v_i.unsqueeze(1).expand(-1, group_size, -1, -1) + # Flatten the first two dimensions for computation + q_i = q_i.reshape(nheads_k * group_size, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k * group_size, seqlen_k, head_dim) + else: + # Standard case + q_i = q_i.reshape(nheads_q, seqlen_q, head_dim) + k_i = k_i.reshape(nheads_k, seqlen_k, head_dim) + v_i = v_i.reshape(nheads_k, seqlen_k, head_dim) + + if alibi_slopes is not None: + alibi_slopes_i = 
alibi_slopes[i] + else: + alibi_slopes_i = None + # Call the core attention function for this sequence - ( - o_i, - softmax_lse_i, - exp_scores_i, - softmax_i, - attention_shifted_scaled_scores_i, - attention_scaled_scores_i, - attention_scores_i, - ) = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, use_exp2) - - # Convert back to 'thd' layout and float16 - o_i = o_i.permute(1, 0, 2).to(torch.float16) # [L_q_i, num_heads, head_dim] + o_i, softmax_lse_i, sd_mask_i = attention_forward_core_ref_impl(q_i, k_i, v_i, sm_scale, causal, dropout_p, philox_seed, philox_offset, alibi_slopes_i, use_exp2) + + # Reshape outputs back to original dimensions + if group_size != 1: + # Reshape outputs to [nheads_k, group_size, seqlen_q, head_dim] + o_i = o_i.reshape(nheads_k, group_size, seqlen_q, head_dim) + # Combine the first two dimensions back to nheads_q + o_i = o_i.reshape(nheads_q, seqlen_q, head_dim) + # Reshape softmax_lse_i similarly + softmax_lse_i = softmax_lse_i.reshape(nheads_k, group_size, seqlen_q) + softmax_lse_i = softmax_lse_i.reshape(nheads_q, seqlen_q) + else: + # Outputs are already in the correct shape + pass + + # Convert back to 'thd' layout + o_i = o_i.permute(1, 0, 2) # [L_q_i, nheads_q, head_dim] + softmax_lse_i = softmax_lse_i.permute(1, 0) # [L_q_i, nheads_q] + sd_mask_i = sd_mask_i # [nheads_q, L_q_i, L_k_i] # Place outputs in pre-allocated tensors o[start_q:end_q, :, :] = o_i - softmax_lse[start_q:end_q, :] = softmax_lse_i.transpose(0, 1) # Transpose to [L_q_i, num_heads] - - # For variable-sized outputs, map them into the preallocated tensors - # exp_scores_i: [num_heads, L_q_i, L_k_i] -> [L_q_i, num_heads, L_k_i] - exp_scores_i = exp_scores_i.permute(1, 0, 2) - softmax_i = softmax_i.permute(1, 0, 2) - attention_shifted_scaled_scores_i = attention_shifted_scaled_scores_i.permute(1, 0, 2) - attention_scaled_scores_i = attention_scaled_scores_i.permute(1, 0, 2) - attention_scores_i = attention_scores_i.permute(1, 0, 2) - - return ( - o, - softmax_lse, - None, - None, - None, - None, - None, - ) + softmax_lse[start_q:end_q, :] = softmax_lse_i + sd_mask[i, :, :seqlen_q, :seqlen_k] = sd_mask_i + + return o, softmax_lse, sd_mask -def attention_forward_pytorch_ref_impl( - q, - k, - v, - sm_scale, - causal, - layout, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - use_exp2 - ): - if DEBUG: - print() - print("attention_forward_pytorch_ref_impl") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("sm_scale:", sm_scale) - print("causal:", causal) - print("cu_seqlens_q:", cu_seqlens_q) - print("cu_seqlens_k:", cu_seqlens_k) - print("max_seqlen_q:", max_seqlen_q) - print("max_seqlen_k:", max_seqlen_k) - print("use_exp2:", use_exp2) - # compute reference +def attention_forward_pytorch_ref_impl( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + philox_seed: Optional[int], + philox_offset: Optional[int], + use_exp2: bool +): + # compute reference if layout == "thd": - ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_varlen_forward_pytorch_ref_impl( + o_ref, softmax_lse_ref, sd_mask_ref = attention_varlen_forward_pytorch_ref_impl( 
q.clone(), k.clone(), v.clone(), sm_scale, - causal, + causal, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, use_exp2, ) else: - ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_vanilla_forward_pytorch_ref_impl( - q.clone(), k.clone(), v.clone(), sm_scale, causal, layout, use_exp2 - ) - - if DEBUG: - print() - print("attention_forward_pytorch_ref_impl outputs") - print("o_ref:", o_ref, o_ref.shape) - print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) - print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape if exp_scores_ref is not None else None) - - return ( - o_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) - - -def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): - q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) - k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) - relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) - return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) \ No newline at end of file + o_ref, softmax_lse_ref, sd_mask_ref = attention_vanilla_forward_pytorch_ref_impl( + q.clone(), + k.clone(), + v.clone(), + sm_scale, + causal, + layout, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + use_exp2) + + # copy back to output tensor + out.copy_(o_ref.to(out.dtype)) + + return softmax_lse_ref, sd_mask_ref diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index 59a306d5d6..bb6e25b509 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -2,34 +2,43 @@ import os from .fwd_prefill import attention_prefill_forward_triton_impl from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_fused import _flash_attn_backward as attention_prefill_backward_triton_fused_impl +from .bwd_prefill_onekernel import attention_prefill_backward_triton_split_oneKernel_impl from .fwd_decode import attention_decode_forward_triton_impl from .fwd_ref import attention_forward_pytorch_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .utils import MetaData, get_shape_from_layout, DEBUG - -USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') - -def fwd(q, - k, - v, - o, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen_): - +from .utils import DEBUG, USE_REF, MetaData, get_shapes_from_layout, is_fp8 +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +from typing import Literal, Optional, Union + +def fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: 
Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): + if DEBUG: print() - print("flash_attn_triton_amd.py::fwd") + print("flash_attn_triton_amd.py::fwd inputs") print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) - print("o:", o) + print("out:", out, out.shape if out is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("softmax_scale:", softmax_scale) @@ -37,15 +46,17 @@ def fwd(q, print("window_size_left:", window_size_left) print("window_size_right:", window_size_right) print("softcap:", softcap) - print("softcap:", softcap) print("return_softmax:", return_softmax) - - - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD's Triton Backend yet") - - if o is None: - o = torch.empty_like(q) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) + + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + else: + out = torch.zeros_like(q) if out is None else out.zero_() # Setup metadata metadata = MetaData(sm_scale=softmax_scale) @@ -55,111 +66,129 @@ def fwd(q, if return_softmax: metadata.return_scores = True - batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout) - + batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, metadata.layout) + if causal: - metadata.need_causal() - + metadata.need_causal(True) + if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - + if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - - # Check arguments - metadata.check_args(q, k, v, o) + metadata.need_dropout(dropout_p) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensor uses the underlying data and does not cast + else: + rng_state = None + + # check arguments + metadata.check_args(q, k, v, out) + + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - (output, - softmax_lse, - exp_scores, - _, - _, - _, - _) = attention_forward_pytorch_ref_impl( - q, - k, + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q, + k, v, - metadata.sm_scale, + out, + metadata.sm_scale, + metadata.alibi_slopes, metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, + metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.use_exp2) - o.copy_(output) + softmax_lse=softmax_lse_ref + sd_mask=sd_mask_ref else: if DEBUG: print("Using Triton implementation") - (_, - softmax_lse, - exp_scores, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - o, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k, + v, + out, + metadata.sm_scale, + metadata.alibi_slopes, + 
metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_o) + softmax_lse=softmax_lse_triton + sd_mask=sd_mask_triton if DEBUG: - print("fwd outputs") - print("o:", o, o.shape) + print("flash_attn_triton_amd.py::fwd outputs") + print("o:", out, out.shape) + if is_fp8(out): + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None ) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - return o, softmax_lse, exp_scores, None + return out, softmax_lse, sd_mask, rng_state +BWD_MODE = os.environ.get('BWD_MODE', 'split').lower() def bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - alibi_slopes, - dropout_p, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen_, - rng_state, + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_: Optional[torch.Tensor] = None, + rng_state:Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + descale_dq: Optional[torch.Tensor] = None, + descale_dk: Optional[torch.Tensor] = None, + descale_dv: Optional[torch.Tensor] = None, ): if DEBUG: print() - print("flash_attn_triton_amd.py::bwd") + print("flash_attn_triton_amd.py::bwd inputs") print("dout:", dout, dout.shape) print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) print("out:", out, out.shape) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) print("alibi_slopes:", alibi_slopes) print("dropout_p:", dropout_p) print("out:", out) @@ -170,37 +199,31 @@ def bwd( print("deterministic:", deterministic) print("gen_:", gen_) print("rng_state:", rng_state) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) + print("descale_do:", descale_do, descale_do.shape if descale_do is not None else None) + print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) + print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is 
not None else None) + print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) + + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD yet") + if dropout_p > 0.0: + assert rng_state is not None + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( - dout, - q, - k, - v, - out, - softmax_lse, - softmax_scale, - causal, - "bshd", - None, - None, - None, - None, - False, - ) - dq.copy_(dq_ref) - dk.copy_(dk_ref) - dv.copy_(dv_ref) - delta = delta_ref - else: - if DEBUG: - print("Using Triton implementation") - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( + + delta_ref = attention_backward_pytorch_ref_impl( dout, q, k, @@ -218,39 +241,144 @@ def bwd( None, None, None, + dropout_p, + philox_seed, + philox_offset, False, ) - delta = delta_triton + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + False, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, + ) + delta = delta_triton + elif BWD_MODE == "fused": + delta_triton = attention_prefill_backward_triton_fused_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + None, + None, + q.shape[1], + k.shape[1], + dropout_p, + philox_seed, + philox_offset, + descale_q, + descale_k, + descale_v, + descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "jingning": + delta_triton = attention_prefill_backward_triton_split_oneKernel_impl( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + alibi_slopes, + causal, + "bshd", + None, + None, + None, + None, + dropout_p, + philox_seed, + philox_offset, + False + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") if DEBUG: - print("bwd outputs") + print("flash_attn_triton_amd.py::bwd outputs") print("dv:", dv, dv.shape) + if is_fp8(dv): + print("descale_dv:", descale_dv, descale_dv.shape if descale_dv is not None else None) print("dk:", dk, dk.shape) + if is_fp8(dk): + print("descale_dk:", descale_dk, descale_dk.shape if descale_dk is not None else None) print("dq:", dq, dq.shape) + if is_fp8(dq): + print("descale_dq:", descale_dq, descale_dq.shape if descale_dq is not None else None) return dq, dk, dv, delta def varlen_fwd( - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - leftpad_k, - block_table_, - alibi_slopes,\ - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - return_softmax, - gen_): + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + seqused_k: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + 
block_table_: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + return_softmax: bool, + gen_: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None + ): if DEBUG: print() @@ -269,120 +397,137 @@ def varlen_fwd( print("window_size_left:", window_size_left) print("window_size_right:", window_size_right) print("gen_:", gen_) + print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD's Triton Backend yet") - - if o is None: - o = torch.empty_like(q) + if is_fp8(q): + assert out is not None, "fp8 output tensor should be passed in." + assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + else: + out = torch.zeros_like(q) if out is None else out.zero_() # Setup metadata metadata = MetaData(sm_scale=softmax_scale) if return_softmax: metadata.return_scores = True - metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) # set layout to "thd" and other metdata + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) # set layout to "thd" and other metadata + assert metadata.layout is not None # get shapes - batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + batch, nheads_q, nheads_k, head_size , seqlen_q, seqlen_k = get_shapes_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) if causal: - metadata.need_causal() + metadata.need_causal(True) if alibi_slopes is not None: metadata.need_alibi(alibi_slopes, batch, nheads_q) - + if dropout_p > 0.0: - metadata.need_dropout(dropout_p, return_softmax) - + metadata.need_dropout(dropout_p) + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) # as_tensor uses the underlying data and does not cast + else: + rng_state = None + # Check arguments - metadata.check_args(q, k, v, o) - if o is None: - o = torch.empty_like(q, dtype=v.dtype) + metadata.check_args(q, k, v, out) + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - (output, - softmax_lse, - exp_scores, - _, - _, - _, - _) = attention_forward_pytorch_ref_impl( - q, - k, + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q, + k, v, - metadata.sm_scale, + out, + metadata.sm_scale, + metadata.alibi_slopes, metadata.causal, - metadata.layout, - metadata.cu_seqlens_q, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, + metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.use_exp2) - o.copy_(output) + softmax_lse=softmax_lse_ref + sd_mask=sd_mask_ref else: if DEBUG: print("Using Triton implementation") - (_, - softmax_lse, - exp_scores, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - o, - 
metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k, + v, + out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + descale_q, + descale_k, + descale_v, + descale_o) + softmax_lse=softmax_lse_triton + sd_mask=sd_mask_triton + if DEBUG: print("varlen_fwd outputs") - print("o:", o, o.shape) + print("out:", out, out.shape) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None ) + print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) - return o, softmax_lse, exp_scores, None + return out, softmax_lse, sd_mask, rng_state def varlen_bwd( - dout, - q, - k, - v, - out, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - zero_tensors, - causal, - window_size_left, - window_size_right, - softcap, - deterministic, - gen_, - rng_state, + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + dropout_p: float, + softmax_scale: float, + zero_tensors: bool, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + gen_ : Optional[torch.Tensor] = None, + rng_state: Optional[torch.Tensor] = None, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, + descale_v: Optional[torch.Tensor] = None, + descale_o: Optional[torch.Tensor] = None, + descale_do: Optional[torch.Tensor] = None, + descale_dq: Optional[torch.Tensor] = None, + descale_dk: Optional[torch.Tensor] = None, + descale_dv: Optional[torch.Tensor] = None, ): if DEBUG: print() @@ -391,17 +536,17 @@ def varlen_bwd( print("q:", q, q.shape) print("k:", k, k.shape) print("v:", v, v.shape) + print("out:", out) print("softmax_lse:", softmax_lse, softmax_lse.shape) - print("dq:", dq, dq.shape) - print("dk:", dk, dk.shape) - print("dv:", dv, dv.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape) print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape) print("alibi_slopes:", alibi_slopes) print("max_seqlen_q:", max_seqlen_q) print("max_seqlen_k:", max_seqlen_k) print("dropout_p:", dropout_p) - print("out:", out) print("softmax_scale:", softmax_scale) print("causal:", causal) print("window_size_left:", window_size_left) @@ -409,37 +554,53 @@ def varlen_bwd( print("deterministic:", deterministic) print("gen_:", gen_) print("rng_state:", rng_state) + print("descale_q:", descale_q, descale_q.shape if descale_q is 
not None else None) + print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) + print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) + print("descale_do:", descale_do, descale_do.shape if descale_do else None) - if dropout_p != 0.0: - raise ValueError("dropout is not supported on AMD yet") + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + if dropout_p > 0.0: + assert rng_state is not None + philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item() + else: + philox_seed, philox_offset = None, None + # call implementation if USE_REF: if DEBUG: print("Using reference implementation") - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + delta_ref = attention_backward_pytorch_ref_impl( dout, q, k, v, out, softmax_lse, + dq, + dk, + dv, softmax_scale, + alibi_slopes, causal, "thd", cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, False, ) - dq.copy_(dq_ref) - dk.copy_(dk_ref) - dv.copy_(dv_ref) delta = delta_ref else: if DEBUG: - print("Using Triton implementation") - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( + print("Using Triton implementation") + delta_triton = attention_prefill_backward_triton_split_impl( dout, q, k, @@ -457,7 +618,18 @@ def varlen_bwd( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, False, + descale_q, + descale_k, + descale_v, + descale_o, + descale_do, + descale_dq, + descale_dk, + descale_dv, ) delta = delta_triton @@ -471,29 +643,54 @@ def varlen_bwd( return dq, dk, dv, delta def fwd_kvcache( - q, - k_cache, - v_cache, - k, - v, - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - out, - softmax_scale, - causal, - window_size_left, - window_size_right, - softcap, - rotary_interleaved, - num_splits): - - if out is None: - out = torch.empty_like(q) + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cache_seqlens: Optional[Union[(int, torch.Tensor)]], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + cache_leftpad: Optional[torch.Tensor], + block_table: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + out: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + rotary_interleaved: bool, + num_splits: int + ): + + if DEBUG: + print() + print("flash_attn_triton_amd.py::fwd_kvcache inputs") + print("q:", q, q.shape) + print("k_cache:", k_cache, k_cache.shape) + print("v_cache:", v_cache, v_cache.shape) + print("k:", k, k.shape if k is not None else None) + print("v:", v, v.shape if v is not None else None) + print("cache_seqlens:", cache_seqlens ) + print("rotary_cos:",rotary_cos ) + print("rotary_sin:",rotary_sin) + print("cache_batch_idx:", cache_batch_idx) + print("cache_leftpad:", cache_leftpad) + print("block_table:", block_table) + print("alibi_slopes:", alibi_slopes) + print("out:", out) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("rotary_interleaved:", 
rotary_interleaved) + print("num_splits:", num_splits) + + # output + out = torch.zeros_like(q) if out is None else out.zero_() # fill metadata metadata = MetaData(sm_scale=softmax_scale) @@ -503,33 +700,99 @@ def fwd_kvcache( metadata.cache_seqlens = cache_seqlens metadata.cache_batch_idx = cache_batch_idx - if k is not None and v is not None: - metadata.new_kv = True - metadata.seqlen_new = k.shape[1] - metadata.k_new = k - metadata.v_new = v + k_new = k + v_new = v if causal: - metadata.need_causal() + metadata.need_causal(True) if alibi_slopes is not None: batch, _ , nheads_q, _= q.shape metadata.need_alibi(alibi_slopes, batch, nheads_q) + # rotary boolean + apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin) + if apply_rotary: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Rotary Embedding Implementation + if apply_rotary: + if metadata.causal: # NOTE: when support is added. Add `or metadata.local` + q_ro = apply_rotary_emb( + q, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=metadata.max_seqlens_q, + ) + k_ro = apply_rotary_emb( + k_new, + metadata.rotary_cos, + metadata.rotary_sin, + seqlen_offsets=metadata.cache_seqlens, + interleaved=metadata.rotary_interleaved, + ) + + q, k_new = q_ro.to(q.dtype), k_ro.to(q.dtype) + # launch kernel - # TODO: pass output as an arg. Maybe we are copying output which is causing slow down - output, softmax_lse = attention_decode_forward_triton_impl( - q, - k_cache, - v_cache, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.new_kv, - metadata.k_new, - metadata.v_new, - ) - return output, softmax_lse + DECODE_KERNEL= True # os.environ.get('DECODE_KERNEL', '0').lower() in ('1', 'true', 'yes') + if DECODE_KERNEL: + softmax_lse_triton = attention_decode_forward_triton_impl( + q, + k_cache, + v_cache, + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + metadata.alibi_slopes, + metadata.layout, + metadata.cache_seqlens, + metadata.cache_batch_idx, + ) + else: + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, + k_cache, + v_cache, + out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + None, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_scores, + metadata.use_exp2, + None, + None, + None, + None) + softmax_lse = softmax_lse_triton + + if DEBUG: + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + return out, softmax_lse diff --git a/flash_attn/flash_attn_triton_amd/interface_torch.py b/flash_attn/flash_attn_triton_amd/interface_torch.py deleted file mode 100644 index d4906606ed..0000000000 --- a/flash_attn/flash_attn_triton_amd/interface_torch.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl -from .fwd_decode import 
attention_decode_forward_triton_impl - - -class _attention_prefill(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, o, metadata): - (output, - softmax_lse, - exp_scores, - grid, - head_size, - philox_seed, - philox_offset, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - o, - metadata.sm_scale, - metadata.alibi_slopes, - metadata.causal, - metadata.bias, - metadata.dropout_p, - metadata.layout, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - metadata.return_scores, - metadata.use_exp2) - - ctx.save_for_backward(q, k, v, o, softmax_lse) - ctx.grid = grid - ctx.sm_scale = metadata.sm_scale - ctx.head_size = head_size - ctx.causal = metadata.causal - ctx.alibi_slopes = metadata.alibi_slopes - ctx.dropout_p = metadata.dropout_p - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.exp_scores = exp_scores - ctx.return_scores = metadata.return_scores - ctx.layout = metadata.layout - ctx.use_exp2 = metadata.use_exp2 - return output, softmax_lse, exp_scores - - @staticmethod - def backward(ctx, do, *args): - q, k, v, o, softmax_lse = ctx.saved_tensors - return attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - None, - None, - None, - ctx.sm_scale, - ctx.alibi_slopes, - ctx.causal, - ctx.layout, - None, - None, - None, - None, - ctx.use_exp2 - ) - -attention_prefill = _attention_prefill.apply - - -class _attention_decode(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, metadata): - output, softmax_lse = attention_decode_forward_triton_impl( - q, - k, - v, - metadata.sm_scale, - metadata.causal, - metadata.alibi_slopes, - metadata.layout, - metadata.cache_seqlens, - metadata.cache_batch_idx, - metadata.new_kv, - metadata.k_new, - metadata.v_new, - ) - return output, softmax_lse - -attention_decode = _attention_decode.apply diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 9a6ab8dab2..58e2ae5fc7 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -1,617 +1,348 @@ +import os +import glob +import shutil +import time import torch import pytest - -from .utils import MetaData, get_input_shapes, input_helper, varlen_input_helper, DEBUG -from .interface_torch import attention_prefill, attention_decode -from .fwd_ref import attention_forward_pytorch_ref_impl, compute_alibi_tensor_ref +import logging +import numpy as np +from pathlib import Path +from flash_attn import ( + flash_attn_func, + flash_attn_fp8_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_qkvpacked_fp8_func, + flash_attn_varlen_func, + flash_attn_varlen_fp8_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_qkvpacked_fp8_func +) + +from .utils import DEBUG, input_helper, arch_supports_fp8 +from .fwd_ref import attention_forward_pytorch_ref_impl from .fwd_prefill import attention_prefill_forward_triton_impl -from .bwd_prefill import attention_prefill_backward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl from .bwd_ref import attention_backward_pytorch_ref_impl -from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4 + +# set print options +# torch.set_printoptions(linewidth=5e5, edgeitems=10, sci_mode=False) +# np.set_printoptions(linewidth=5000, threshold=1e4, suppress=True, precision=4) # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. 
See table https://pytorch.org/docs/stable/testing.html ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. # ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues # ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs # ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs +# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 +ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 +# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 EQUAL_NAN = True -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 24, 1024, 1024, 64), - (1, 24, 6, 8192, 8192, 64), - (1, 4, 2, 16384, 16384, 128), - (2, 16, 4, 1020, 987, 128), - (2, 16, 4, 15498, 2, 128), - (2, 16, 2, 7, 16219, 64), - (4, 48, 12, 1, 1, 64), - (4, 48, 48, 1, 1, 128), - (4, 48, 24, 3, 3, 128), - (4, 48, 48, 1001, 990, 64), - (1, 8, 8, 8081, 7099, 64), - (1, 4, 4, 16330, 15989, 128), - (4, 4, 1, 1024, 1024, 33), - (4, 4, 2, 65, 1018, 65), - (4, 4, 4, 128, 128, 65), - (4, 4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_alibi', [True, False]) -@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) -def test_op_fwd_prefill(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): - torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) - if causal: - input_metadata.need_causal() - - if use_alibi: - # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, HQ) - else: - alibi_slopes = None - - o = torch.empty_like(q) - - # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - - # Transpose here if layout is bshd so we have same reference code for all layouts - if layout == 'bshd': - q = q.transpose(1, 2).clone() - k = k.transpose(1, 2).clone() - v = v.transpose(1, 2).clone() - # Replicate K and V if using MQA/GQA - if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) - - scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale - if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") - if use_alibi: - scores += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K) - - p = torch.softmax(scores, dim=-1) - if causal: - # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into - # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix - # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
- nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) - # compare - if layout == 'bshd': - ref_out = ref_out.transpose(1, 2).clone() - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1024, 1024, 64), - (4, 12, 8192, 8192, 64), - (2, 4, 16384, 16384, 128), - (2, 16, 15498, 2, 128), - (2, 4, 7, 16219, 64), - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 48, 1001, 990, 64), - (1, 8, 8081, 7099, 64), - (1, 8, 16330, 15989, 128), - (4, 4, 1024, 1024, 33), - (4, 4, 65, 1019, 65), - (4, 4, 128, 128, 65), - # TODO: This config fails. Disabled until triaged and fixed. - # (2, 16, 1020, 987, 128), - # (4, 4, 113, 123, 1), -]) +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 64, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (3, 2, 2, 256, 512, 16), + (3, 3, 3, 128, 128, 64), + (2, 4, 4, 1024, 1024, 64), + (4, 6, 6, 108, 256, 224), + (4, 8, 8, 2048, 2048, 128), + (4, 16, 16, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # fa configs + (4, 6, 1, 113, 203, 256), + (4, 6, 1, 128, 217, 256), + (4, 6, 2, 113, 211, 128), + (4, 6, 2, 108, 256, 128), + (4, 6, 1, 256, 512, 64), + (4, 6, 1, 512, 256, 64), + (4, 6, 2, 1024, 1024, 32), + (4, 6, 2, 1023, 1024, 32), + (4, 6, 6, 1024, 1023, 32), + (4, 6, 6, 2048, 2048, 32), + ], +) @pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_bias', [True]) -def test_op_fwd_prefill_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16): - torch.manual_seed(20) - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') - if causal: - input_metadata.need_causal() - if use_bias: - bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda") - input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) - else: - bias = None - o = torch.empty_like(q) - - # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - # reference implementation:171 - - scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale - if causal: - mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) - scores[:, :, mask == 0] = float("-inf") - if use_bias: - scores += input_metadata.bias - p = torch.softmax(scores, dim=-1) - if causal: - # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into - # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix - # this by converting the NaNs to 0s, which is what they should be out of the softmax. 
- nan_mask = torch.isnan(p) - p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) - # compare - torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) - - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ - (4, 48, 8192, 64), - (4, 48, 256, 64), - (4, 48, 512, 64), - (4, 48, 1024, 64), - (8, 48, 4096, 64), - (4, 48, 8192, 64), - (4, 48, 128, 128), - (4, 48, 4096, 128), - (4, 48, 16384, 128), - (4, 16, 1024, 128), - (4, 16, 8192, 128), - (32, 48, 8192, 128) - ] - ) -@pytest.mark.parametrize('causal', [True, False]) -def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) - - tri_out = torch.empty_like(q) - ref_out = torch.empty_like(q) - - for i in range(0, input_metadata.num_contexts): - start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() - p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) - attention_prefill(q, k, v, tri_out, input_metadata) - torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), - (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), - (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), - (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), - (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) -@pytest.mark.parametrize('causal', [False]) -def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) - ref_out = torch.empty_like(q) - tri_out = torch.empty_like(q) - # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the - # size aligns with Q. 
- k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) - v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) - for i in range(0, input_metadata.num_contexts): - start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] - end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] - k_curr = k_ref[start_k:end_k] - k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) - v_curr = v_ref[start_k:end_k] - v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() - p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) - attention_prefill(q, k, v, tri_out, input_metadata) - torch.testing.assert_close(ref_out, tri_out, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - # smallest config test - (1, 1, 16, 16, 64), # pass on new # fail on old - (1, 1, 32, 32, 64), # pass on new # fail on old - (1, 1, 64, 64, 16), # pass # smallest head_size = 16 - (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 - (1, 1, 128, 128, 64), # pass - (1, 1, 256, 256, 64), # pass - (1, 1, 512, 512, 64), # pass - # failing FA - (1, 1, 256, 512, 16), - # old tests that work - (4, 48, 1024, 1024, 64), # pass - (4, 48, 2048, 2048, 64), # pass - (2, 48, 4096, 4096, 64), # pass - (1, 16, 1024, 1024, 64), # pass - (1, 16, 1024, 1024, 128), # pass - # old tests that were commented out - # (1, 16, 8192, 8192, 63), - # (1, 16, 1022, 1022, 64), -]) -# @pytest.mark.parametrize('torch_sdpa_test', [False, True]) -@pytest.mark.parametrize('torch_sdpa_test', [False]) -# @pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('causal', [False]) -# @pytest.mark.parametrize('use_alibi', [False, True]) -@pytest.mark.parametrize('use_alibi', [False]) -def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16): - torch.manual_seed(20) - - DEBUG_INPUT = False - - # seqlens - seqlen_q = N_CTX_Q - seqlen_k = N_CTX_K - - # setup up metadata - if DEBUG_INPUT: - sm_scale = 1 - else: - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = seqlen_q - input_metadata.max_seqlens_k = seqlen_k - input_metadata.layout = "bhsd" - - dropout_p = 0 - if DEBUG_INPUT: - q = torch.arange(seqlen_q, dtype=dtype, device="cuda").view(1, 1, seqlen_q, 1).expand(Z, H, seqlen_q, D_HEAD).requires_grad_() - k = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() - v = torch.arange(seqlen_k, dtype=dtype, device="cuda").view(1, 1, seqlen_k, 1).expand(Z, H, seqlen_k, D_HEAD).requires_grad_() - o = torch.zeros_like(q) - else: - # Generate random inputs - q = torch.randn(Z, H, N_CTX_Q, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - k = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True) - o = torch.empty_like(q) - - if causal: - input_metadata.need_causal() - - if use_alibi and not torch_sdpa_test: - # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) - alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) - input_metadata.need_alibi(alibi_slopes, Z, 
H) - - if DEBUG_INPUT: - dout = torch.ones_like(q) - else: - dout = torch.randn_like(q) - - # reference implementation - if torch_sdpa_test: - ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, - is_causal=causal, scale=sm_scale, - dropout_mask=None) - ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - else: - M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - if use_alibi: - p += compute_alibi_tensor_ref(alibi_slopes, N_CTX_Q, N_CTX_K) - if causal: - p[:, :, M == 0] = float("-inf") - - p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - - # # triton implementation - tri_out, _, _ = attention_prefill(q, k, v, o, input_metadata) - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None - # compare - if DEBUG: - print("tri_out:", tri_out) - print("ref_out:",ref_out ) - torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) - - # The current block size for MI200 series is 64x64. This results in - # larger differences in float results due to rounding. - if dtype == torch.bfloat16: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - if dtype == torch.float32: - ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - else: - ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) - - RTOL = 0 - - if DEBUG: - print("ref_dv:", ref_dv) - print("tri_dv:", tri_dv) - print("ref_dk:", ref_dk) - print("tri_dk:", tri_dk) - print("ref_dq:", ref_dq) - print("tri_dq:", tri_dq) - - torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) - torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) - torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 1, 1, 1, 1), - (1, 1, 2, 4, 16), - (1, 1, 4, 2, 16), - (1, 1, 4, 4, 16), - (1, 2, 4, 4, 16), - (2, 1, 4, 4, 16), - (2, 2, 4, 4, 16), - (1, 1, 128, 64, 16), - (2, 2, 2, 128, 1), - (2, 3, 2, 128, 16), - (3, 2, 256, 512, 16), - (3, 3, 128, 128, 64), - (2, 4, 1024, 1024, 64), - (4, 6, 108, 256, 224), - (4, 8, 2048, 2048, 128), - (4, 16, 4096, 4096, 64), - (2, 4, 8192, 8192, 32), - # # fa configs - (4, 6, 113, 203, 256), - (4, 6, 128, 217, 256), - (4, 6, 113, 211, 128), - (4, 6, 108, 256, 128), - (4, 6, 256, 512, 64), - (4, 6, 512, 256, 64), - (4, 6, 1024, 1024, 32), - (4, 6, 1023, 1024, 32), - (4, 6, 1024, 1023, 32), - (4, 6, 2048, 2048, 32), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('return_scores', [False]) -@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('alibi_slopes', [None]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false @pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. 
Just use to figure out issues -def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT): - dtype = torch.float16 - torch.manual_seed(0) - alibi_slopes = None - dropout_p = 0.0 +def test_op_prefill_fwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): + torch.manual_seed(42) device = "cuda" - if layout == "thd": - q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) - else: - q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) - if DEBUG_INPUT: - output_triton = torch.zeros_like(q).contiguous() - else: - output_triton = torch.empty_like(q) + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) + + if DEBUG: + if HQ // HK != 1: + print("MQA/GQA") + else: + print("MHA") # update metadata metadata.use_exp2 = use_exp2 if causal: - metadata.need_causal() + metadata.need_causal(True) # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that - if return_scores: - metadata.return_scores = True + metadata.need_dropout(dropout_p) + # call Triton's forward implementation directly - ( output_triton, - softmax_lse_triton, - exp_scores_triton, - _, - _, - _, - _, - _, - _) = attention_prefill_forward_triton_impl( - q, - k, - v, - output_triton, + q_triton = q.clone() + k_triton = k.clone() + v_triton = v.clone() + o_triton = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q_triton, + k_triton, + v_triton, + o_triton, metadata.sm_scale, metadata.alibi_slopes, metadata.causal, metadata.bias, - metadata.dropout_p, metadata.layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, - metadata.max_seqlens_k, + metadata.max_seqlens_k, + metadata.cache_seqlens, + metadata.cache_batch_idx, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, metadata.return_scores, - metadata.use_exp2) - - ( - output_ref, - softmax_lse_ref, - exp_scores_ref, - softmax_ref, - attention_shifted_scaled_scores_ref, - attention_scaled_scores_ref, - attention_scores_ref, - ) = attention_forward_pytorch_ref_impl( - q.clone(), - k.clone(), - v.clone(), - metadata.sm_scale, + metadata.use_exp2, + None, + None, + None, + None) + + # ref forward + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + o_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + q_ref, + k_ref, + v_ref, + o_ref, + metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) + if DEBUG: + print() + print("Compare Triton Impl with refernce Pytorch Impl") + + # this can be set to true manually or when using dropout + if metadata.return_scores: + if DEBUG: + print("sd_mask_triton:", sd_mask_triton, sd_mask_triton.shape) + print("sd_mask_ref:", sd_mask_ref, sd_mask_ref.shape) + torch.testing.assert_close(sd_mask_triton.to(sd_mask_ref.dtype), sd_mask_ref, atol=ATOL, rtol=RTOL) + if DEBUG: print("softmax_lse_triton:", 
softmax_lse_triton, softmax_lse_triton.shape) print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape) torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL) - - if layout != "thd": - # use trick with lse to get the softmax. you need the scores but is it - softmax_triton = torch.exp(attention_scaled_scores_ref - softmax_lse_triton.unsqueeze(-1)) - if DEBUG: - print("attention_scaled_scores_ref:", attention_scaled_scores_ref, attention_scaled_scores_ref.shape) - print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape) - print("softmax_triton:", softmax_triton, softmax_triton.shape) - print("softmax_ref:", softmax_ref, softmax_ref.shape) - torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL) if DEBUG: - print("output_triton:", output_triton, output_triton.shape) - print("output_ref:", output_ref, output_ref.shape) - torch.testing.assert_close(output_triton, output_ref, atol=ATOL, rtol=RTOL) - - - # compare with pytorch expect thd and causal impl is different - if False and layout in ["bhsd", "bshd"] and not causal: - out_pytorch, softmax_pytorch = torch.ops.aten._scaled_dot_product_attention_math( - q.transpose(1, 2) if layout == "bshd" else q , - k.transpose(1, 2) if layout == "bshd" else k, - v.transpose(1, 2) if layout == "bshd" else v, - dropout_p=dropout_p, - is_causal=causal, scale=metadata.sm_scale, - dropout_mask=None) - out_pytorch = out_pytorch.transpose(1, 2) if layout == "bshd" else out_pytorch - - if DEBUG: - print("o:", output_triton, output_triton.shape) - print("out_pytorch:", out_pytorch, out_pytorch.shape) - torch.testing.assert_close(output_triton, out_pytorch, atol=ATOL, rtol=RTOL) - - # compare with pytorch output - if DEBUG: - print("softmax_triton:", softmax_triton, softmax_triton.shape) - print("softmax_pytorch:", softmax_pytorch, softmax_pytorch.shape) - torch.testing.assert_close(softmax_triton, softmax_pytorch.to(torch.float32), atol=ATOL, rtol=RTOL) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 1, 1, 1, 1), - (1, 1, 4, 4, 4), - (2, 1, 4, 4, 16), - (1, 2, 4, 4, 16), - (2, 2, 4, 4, 16), - (1, 1, 4, 4, 16), - (2, 1, 4, 4 , 16), - (4, 6, 8, 8 , 16), - (1, 1, 4, 4, 32), - (1, 1, 16, 16, 16), - (1, 1, 32, 32, 16), - (1, 1, 64, 64, 16), - (1, 1, 64, 64, 64), - (1, 1, 64, 128, 32), - (1, 1, 128, 128, 64), - (1, 1, 128, 256, 45), - (1, 1, 113, 203, 192), - (1, 1, 256, 256, 64), - (1, 1, 256, 512, 16), - (1, 1, 512, 512, 64), - (1, 1, 1024, 1024, 64), + print("output_triton:", o_triton, o_triton.shape) + print("output_ref:", o_ref, o_ref.shape) + torch.testing.assert_close(o_triton, o_ref, atol=ATOL, rtol=RTOL) + +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", [ + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 4, 4, 4), + (2, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 8, 1, 2, 4, 16), + (1, 16, 1, 2, 4, 16), + (1, 32, 1, 2, 4, 16), + (1, 64, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 4, 4, 16), + (2, 1, 1, 4, 4 , 16), + (4, 6, 6, 8, 8 , 16), + (1, 1, 1, 4, 4, 32), + (1, 1, 1, 16, 16, 16), + (1, 1, 1, 32, 32, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 64, 16), + (1, 1, 1, 64, 128, 16), + (1, 1, 1, 64, 64, 32), + (1, 1, 1, 64, 128, 32), + (1, 1, 1, 128, 128, 64), + (1, 1, 1, 128, 256, 45), + (1, 1, 1, 113, 203, 192), + (1, 1, 1, 256, 256, 64), + (1, 1, 1, 256, 512, 16), + (1, 1, 1, 512, 512, 64), + (1, 1, 1, 1024, 1024, 64), # fa configs - (2, 2, 128, 128, 65), - (2, 2, 128, 128, 224), - (4, 6, 108, 
256, 224), - (1, 1, 256, 512, 16), + (2, 2, 2, 128, 128, 65), + (2, 2, 2, 128, 128, 224), + (4, 6, 6, 108, 256, 224), + (1, 1, 1, 256, 512, 16), # old tests that work - (4, 48, 1024, 1024, 73), - (4, 48, 1024, 1024, 64), - (4, 48, 2048, 2048, 64), - (1, 24, 4096, 4096, 64), - (1, 16, 1024, 1024, 64), - (1, 16, 1024, 1024, 128), + (4, 48, 6, 1024, 1024, 64), + (4, 48, 12, 2048, 1024, 64), + (4, 48, 24, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 64), + (4, 48, 48, 1024, 1024, 73), + (4, 48, 48, 2048, 2048, 64), + (1, 24, 24, 4096, 4096, 64), + (1, 16, 16, 1024, 1024, 64), + (1, 16, 16, 1024, 1024, 128), + # testcase new + # seqlen q == k + (1, 1, 1, 2, 2, 2), # small enough to debug + (1, 1, 1, 128, 128, 32), # only one block + (1, 1, 1, 127, 127, 32), # only one block but with masking + (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 1), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 68), # generic masking on q, k and head + (4, 1, 1, 512, 512, 128), # batch > 1 + (4, 8, 2, 512, 512, 128), # GQA + (4, 8, 2, 512, 512, 68), # non-power-of-2 head_dim + (4, 8, 2, 500, 500, 68), # comprehensive case for seqlen q == k + # seqlen q > k + (1, 1, 1, 64, 32, 8), # seqlen_q > seqlen_k + (1, 1, 1, 192, 128, 32), # seqlen_q > seqlen_k + (4, 8, 2, 1024, 512, 68), # seqlen_q < seqlen_k + (1, 1, 1, 729, 516, 68), # seqlen_q > seqlen_k + (16, 16, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k + # seqlen q < k + (1, 1, 1, 32, 64, 8), # seqlen_q > seqlen_k + (1, 1, 1, 128, 192, 32), # seqlen_q < seqlen_k + (4, 8, 2, 512, 1024, 68), # seqlen_q < seqlen_k + (1, 1, 1, 200, 413, 1), # seqlen_q < seqlen_k + (1, 1, 1, 782, 1546, 1), # seqlen_q < seqlen_k + (16, 16, 4, 1528, 2753, 68), # a comprehensive seqlen_q < seqlen_k ]) @pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('alibi_slopes', [None]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal -@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) -@pytest.mark.parametrize('sequence_parallel', [True, False]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend -def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, sequence_parallel, DEBUG_INPUT): - dtype = torch.float16 - torch.manual_seed(20) # seed from test_op_bwd +@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans on larger tensors +def test_op_prefill_bwd_impl(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, alibi_slopes, layout, dtype, use_exp2, DEBUG_INPUT): + torch.manual_seed(20) + device="cuda" - alibi_slopes = None - if layout == "thd": - q, k, v, metadata = varlen_input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, DEBUG_INPUT=DEBUG_INPUT) - else: - q, k, v, metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, DEBUG_INPUT=DEBUG_INPUT) - if DEBUG_INPUT: - do = torch.ones_like(q).contiguous() - else: - do = torch.randn_like(q) + # gen inputs + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, dtype, layout=layout, device=device) + + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + metadata.need_dropout(dropout_p) # =============================================== Reference ============================================================== + # fwd q_ref = q.clone() k_ref = k.clone() - v_ref = v.clone() - ( - o_ref, - softmax_lse_ref, - _, - _, - _, - _, - _, - ) = attention_forward_pytorch_ref_impl( + v_ref = v.clone() + output_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( q_ref, k_ref, v_ref, - metadata.sm_scale, + output_ref, + metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) - dq = torch.zeros_like(q, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros - if DEBUG_INPUT: - dk = torch.zeros_like(k, dtype=k.dtype) - dv = torch.zeros_like(v, dtype=v.dtype) - else: - dk = torch.empty_like(k, dtype=k.dtype) - dv = torch.empty_like(v, dtype=v.dtype) - + # bwd do_ref = do.clone() - dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl( + dq_ref = torch.zeros_like(q).contiguous() if DEBUG_INPUT else torch.empty_like(q) + dk_ref = torch.zeros_like(k).contiguous() if DEBUG_INPUT else torch.empty_like(k) + dv_ref = torch.zeros_like(v).contiguous() if DEBUG_INPUT else torch.empty_like(v) + delta_ref = attention_backward_pytorch_ref_impl( do_ref, q_ref, k_ref, v_ref, - o_ref, + output_ref, softmax_lse_ref, + dq_ref, + dk_ref, + dv_ref, metadata.sm_scale, + metadata.alibi_slopes, causal, layout, metadata.cu_seqlens_q, metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2 ) # =============================================== Triton ============================================================== - o = o_ref.clone().contiguous() - softmax_lse = softmax_lse_ref.clone().contiguous() - dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl( - do, - q, - k, - v, - o, - softmax_lse, - dq, - dk, - dv, + do_triton = do.clone() + q_triton = q.clone() + k_triton = k.clone() + v_triton = v.clone() + o_triton = output_ref.clone().contiguous() + softmax_lse_triton = softmax_lse_ref.clone().contiguous() + dq_triton = torch.zeros_like(q_triton, dtype=q.dtype) # NOTE: the kernel does inplace accumlation on dq so dq has to be zeros + dk_triton = torch.zeros_like(k_triton, dtype=k.dtype) if DEBUG_INPUT else torch.empty_like(k_triton, dtype=k.dtype) + dv_triton = torch.zeros_like(v_triton, dtype=v.dtype) if DEBUG_INPUT else torch.empty_like(v_triton, dtype=v.dtype) + delta_triton = attention_prefill_backward_triton_split_impl( + do_triton, + q_triton, + k_triton, + v_triton, + o_triton, + softmax_lse_triton, + dq_triton, + dk_triton, + dv_triton, metadata.sm_scale, alibi_slopes, causal, @@ -620,8 +351,18 @@ def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, use_exp2, - sequence_parallel=sequence_parallel + None, + None, + None, + None, + None, + None, + None, + None, ) # =============================================== Check ============================================================== @@ -647,78 +388,545 @@ def test_op_prefill_bwd_impl(Z, H, 
N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l print("dq_ref:", dq_ref, dq_ref.shape) torch.testing.assert_close(dq_triton, dq_ref, atol=ATOL, rtol=RTOL, equal_nan=EQUAL_NAN) +def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): + """Assert tensors are close with tolerance for small percentage of elements""" + # standard comparison + abs_diff = torch.abs(tensor_a - tensor_b) + rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) + + # calculate elements that exceed tolerance + abs_check = abs_diff > atol + rel_check = rel_diff > rtol + failed_check = torch.logical_and(abs_check, rel_check) + + # calculate percentage of failed elements + failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 + + # if percentage is small enough, test passes + if failed_percentage <= max_diff_percentage: + return True + + # Otherwise, provide diagnostic information + max_abs_idx = torch.argmax(abs_diff).item() + max_rel_idx = torch.argmax(rel_diff).item() + + flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) + + max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) + max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) + + max_abs_diff = abs_diff.flatten()[max_abs_idx].item() + max_rel_diff = rel_diff.flatten()[max_rel_idx].item() + + raise AssertionError( + f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" + f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" + f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" + ) + +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + # seqlen q == k + (1, 1, 1, 1, 1, 1), + (1, 1, 1, 2, 2, 2), # small enough to debug + (1, 1, 1, 4, 4, 16), + (1, 2, 2, 4, 4, 16), + (2, 1, 1, 4, 4, 16), + (2, 2, 2, 4, 4, 16), + (1, 1, 1, 128, 128, 32), # only one block + (3, 3, 3, 128, 128, 64), + (1, 1, 1, 127, 127, 32), # only one block but with masking + # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails + (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug + (1, 1, 1, 350, 350, 68), # generic masking on q, k and head + (4, 1, 1, 512, 512, 128), # batch > 1 + (4, 2, 2, 512, 512, 128), + (4, 2, 2, 512, 512, 68), + (4, 2, 2, 500, 500, 68), + (2, 4, 4, 1024, 1024, 64), + (4, 8, 8, 2048, 2048, 128), + (2, 8, 8, 4096, 4096, 64), + (2, 4, 4, 8192, 8192, 32), + # seqlen q > k + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 64, 32, 8), + (1, 1, 1, 128, 64, 16), + (1, 1, 1, 192, 128, 32), + (1, 2, 2, 1024, 512, 68), + (1, 4, 4, 729, 516, 68), + (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k + # seqlen q < k + (1, 1, 1, 2, 4, 16), + (1, 2, 2, 2, 4, 16), + (1, 4, 1, 2, 4, 16), + (1, 4, 2, 2, 4, 16), + (2, 2, 2, 2, 128, 1), + (2, 3, 3, 2, 128, 16), + (1, 1, 1, 32, 64, 8), + (1, 1, 1, 128, 192, 32), + (4, 6, 6, 108, 256, 32), + (3, 2, 2, 256, 512, 16), + (2, 2, 2, 512, 1024, 68), + (1, 1, 1, 200, 413, 32), + (1, 1, 1, 782, 1546, 32), + # gqa/mqa # mismatch issue on varlen + (4, 8, 2, 500, 500, 68), + (4, 8, 2, 512, 512, 68), + (4, 8, 2, 512, 512, 128), + (4, 8, 2, 512, 1024, 68), + (4, 8, 2, 1024, 512, 64), + (4, 16, 4, 1528, 2753, 68), + # fa configs + (2, 4, 1, 113, 203, 64), + (2, 4, 2, 128, 217, 64), + (2, 6, 2, 113, 211, 128), + (2, 6, 2, 108, 256, 128), + (2, 6, 2, 256, 512, 64), + (2, 6, 2, 512, 256, 64), + (2, 6, 2, 1024, 1024, 
32), + (2, 6, 2, 1023, 1024, 32), + (2, 6, 6, 1024, 1023, 32), + (2, 6, 6, 2048, 2048, 32), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('layout', ["bshd", "thd"]) +@pytest.mark.parametrize('packing', [None, "qkv"]) +@pytest.mark.parametrize('DEBUG_INPUT', [False]) +@pytest.mark.flaky(reruns=3, reason="Retry failures") +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): + torch.manual_seed(20) + test_backward = True + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + ref_dtype = torch.float32 + is_varlen = True if layout == "thd" else False + + # skip QKV packing tests for uneven sequence lengths and head sizes + if packing == 'qkv': + if N_CTX_Q != N_CTX_K: + pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") + if HQ != HK: + pytest.skip("QKV packing requires HQ == HK") + + # test apis + if packing == 'qkv': + # generate inputs + qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device, DEBUG_INPUT=DEBUG_INPUT) + + # ---------------------------------------------------------------- + # --- FP8 --- + # ---------------------------------------------------------------- + qkv_fp8 = qkv.clone() + do_fp8= do.clone() + + if is_varlen: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_qkvpacked_fp8_func( + qkv_fp8, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_qkvpacked_fp8_func( + qkv_fp8, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Reference --- + # ---------------------------------------------------------------- + # reference forward pass + qkv_ref = qkv.clone() + do_ref= do.clone() + + if is_varlen: + out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_qkvpacked_func( + qkv_ref, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_ref, lse_ref, S_dmask_ref = flash_attn_qkvpacked_func( + qkv_ref, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Compare --- + # ---------------------------------------------------------------- + # compare forward + if DEBUG: + print() + print(f"Compare fp8 against ref with dtype {ref_dtype}") -@pytest.mark.parametrize('batch_size, seqlen_q, seqlen_k, group_q, group_k, dim', get_input_shapes()) -def test_op_fwd_decode(batch_size, seqlen_q, seqlen_k, group_q, group_k, dim, dtype=torch.bfloat16): - if DEBUG: - print() - print(f"batch_size = {batch_size}, seqlen_q = {seqlen_q}, seqlen_k = {seqlen_k}, group_q = {group_q}, group_k = {group_k}, dim = {dim}") + if DEBUG: + print("out_ref:", out_ref, out_ref.shape) + 
print("out_fp8:", out_fp8, out_fp8.shape) + fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + + if DEBUG: + print("lse_ref:", lse_ref, lse_ref.shape) + print("lse_fp8:", lse_fp8, lse_fp8.shape) + fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + + if dropout_p > 0.0: + if DEBUG: + print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) + print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) + fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + if not test_backward: + return + + # fp8 backward pass + dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) + + # ref backward pass + dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) + + # compare backward gradients + if DEBUG: + print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) + print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) + fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + elif packing is None: + # generate inputs + q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device, DEBUG_INPUT=DEBUG_INPUT) + + # ---------------------------------------------------------------- + # --- FP8 --- + # ---------------------------------------------------------------- + if DEBUG: + print() + print(f"Compute Fp8 Forward") + q_fp8 = q.clone() + k_fp8 = k.clone() + v_fp8 = v.clone() + do_fp8= do.clone() + + if is_varlen: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_varlen_fp8_func( + q_fp8, + k_fp8, + v_fp8, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_fp8, lse_fp8, S_dmask_fp8 = flash_attn_fp8_func( + q_fp8, + k_fp8, + v_fp8, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Reference --- + # ---------------------------------------------------------------- + if DEBUG: + print() + print(f"Compute Reference Forward") + # reference forward pass + q_ref = q.clone() + k_ref = k.clone() + v_ref = v.clone() + do_ref = do.clone() + + if is_varlen: + out_ref, lse_ref, S_dmask_ref = flash_attn_varlen_func( + q_ref, + k_ref, + v_ref, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out_ref, lse_ref, S_dmask_ref = flash_attn_func( + q_ref, + k_ref, + v_ref, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # ---------------------------------------------------------------- + # --- Compare --- + # ---------------------------------------------------------------- + # compare forward + if DEBUG: + print() + print(f"Compare fp8 against ref with dtype {ref_dtype}") + + if DEBUG: + print("out_ref:", out_ref, out_ref.shape) + print("out_fp8:", out_fp8, out_fp8.shape) + # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + + if DEBUG: + print("lse_ref:", lse_ref, 
lse_ref.shape) + print("lse_fp8:", lse_fp8, lse_fp8.shape) + # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + + if dropout_p > 0.0: + if DEBUG: + print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) + print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) + # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) + + if not test_backward: + return + + if DEBUG: + print() + print(f"Compute Fp8 Backward") + # fp8 backward pass + dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) + + if DEBUG: + print() + print(f"Compute Reference Backward") + # ref backward pass + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) + + # compare backward gradients + if DEBUG: + print("dv_ref:", dv_ref, dv_ref.shape) + print("dv_fp8:", dv_fp8, dv_fp8.shape) + # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + if DEBUG: + print("dk_ref:", dk_ref, dk_ref.shape) + print("dk_fp8:", dk_fp8, dk_fp8.shape) + # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + + if DEBUG: + print("dq_ref:", dq_ref, dq_ref.shape) + print("dq_fp8:", dq_fp8, dq_fp8.shape) + # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) + fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) + +@pytest.mark.parametrize( + "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (2, 4, 4, 512, 512, 128), + ], +) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) +@pytest.mark.parametrize('layout', ['bshd']) +@pytest.mark.parametrize('packing', [None]) +@pytest.mark.parametrize('test_backward', [False, True]) +@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") +@pytest.mark.skip("Breaks on CI but works locally") +def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. 
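# test_ir wipes the shared Triton cache (~/.triton/cache), so it cannot run in
# parallel with other tests. A hedged alternative, assuming Triton honors the
# TRITON_CACHE_DIR environment variable (it must be set before any kernel is
# compiled), would be to point the test at its own temporary cache instead of
# clearing the shared one, for example:
#
#     import os, tempfile
#     os.environ["TRITON_CACHE_DIR"] = tempfile.mkdtemp(prefix="triton_cache_")
#
# With an isolated cache, the .ttir scan below would also only see files produced by this test.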
torch.manual_seed(20) - query_group_head_size = (group_q + group_k - 1) // group_k - q = (torch.empty((batch_size, seqlen_q, group_k, query_group_head_size, dim), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_()) - k = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, - device="cuda").normal_(mean=0., - std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) - v = (torch.empty((batch_size, seqlen_k, group_k, 1, dim), dtype=dtype, - device="cuda").normal_(mean=0., - std=0.5).requires_grad_()).expand(-1, -1, -1, query_group_head_size, -1) - scale = 1 / dim**0.5 - input_metadata = MetaData(sm_scale=scale) - input_metadata.layout = "bsghd" - tri_out, _ = attention_decode(q, k, v, input_metadata) - - q = q.reshape([batch_size, seqlen_q, -1, dim]).permute(0, 2, 1, 3) - k = k.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) - v = v.reshape([batch_size, seqlen_k, -1, dim]).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) - ref_out = attn @ v - - # compare - torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0) - -def test_quantization(): - a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') - qa = quantize_kv_int4(a, num_groups=4) - dqa = dequantize_kv_fp16(qa, num_groups=4) - torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) - -@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', get_input_shapes()) -def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): - pytest.skip("Decode kernel doesnot support quantization yet") - torch.manual_seed(2) - q = (torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, - device="cuda").normal_(mean=1.0, std=0.5).requires_grad_()) - k = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, - device="cuda").normal_(mean=1.0, - std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) - v = (torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, - device="cuda").normal_(mean=1.0, - std=0.5).requires_grad_()).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) - - num_groups = 1 - quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32)) - quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32)) - scale = 1 / K**0.5 - input_metadata = MetaData(sm_scale=scale) - input_metadata.layout = "bsghd" - tri_out, _ = attention_decode(q, quant_k, quant_v, input_metadata) - - q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) - k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) - ref_out = attn @ v - # compare - torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) - - # since quantization introduces rounding error, use the - # dequantized kv as inputs to the ref implementation to reduce - # the tolerance to 1e-3 - dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) - dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) - dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) - dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) - dq_ref_out = dq_attn @ dqv - torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + device = "cuda" + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + ref_dtype = torch.float32 + is_varlen = True if layout == "thd" else False + + # remove cache + cache_path = Path(os.path.expanduser("~/.triton/cache")) + if 
cache_path.exists(): + shutil.rmtree(cache_path) + os.makedirs(cache_path) + + # inputs + q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) + + if packing == None: + # fp8 forward pass + if is_varlen: + out, lse, S_dmask = flash_attn_varlen_fp8_func( + q, + k, + v, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_fp8_func( + q, + k, + v, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # fp8 backward pass + if test_backward: + dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) + elif packing == "qkv": + # qkv packing path + # pack input tensors (use dim=1 for varlen, else dim=2) + if is_varlen: + qkv = torch.stack([q, k, v], dim=1) + else: + qkv = torch.stack([q, k, v], dim=2) + + # fp8 forward pass for qkv-packed input + if is_varlen: + out, lse, S_dmask = flash_attn_varlen_qkvpacked_fp8_func( + qkv, + metadata.cu_seqlens_q, + metadata.max_seqlens_q, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + else: + out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( + qkv, + dropout_p, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + + # fp8 backward pass for qkv-packed input + if test_backward: + dqkv, = torch.autograd.grad(out, (qkv,), do) + else: + raise ValueError(f"unknown packing type {packing}") + + # search for .ttir files + max_retries = 5 + retry_delay = 0.5 + ttir_files = [] + logging.info(f"Checking for .ttir files in {cache_path}...") + for attempt in range(max_retries): + # search for .ttir files recursively within the cache path + ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) + + if ttir_files: + # Files found, log success and exit the loop + logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") + break + else: + # Files not found yet + if attempt < max_retries - 1: + # If not the last attempt, wait and log before retrying + logging.warning( + f"No .ttir files found on attempt {attempt + 1}. " + f"Retrying in {retry_delay}s..." + ) + time.sleep(retry_delay) + else: + pytest.fail( + f"FATAL: No .ttir files found in cache {cache_path} " + f"after {max_retries} attempts." 
+ ) + + # check if there is fp8 + ttir_files_fp8_found_status = {} + fp8_types = ['f8E4M3', 'f8E5M2'] + for ttir_file in ttir_files: + base_name = os.path.basename(ttir_file) + with open(ttir_file, 'r') as f: + content = f.read() + + # check content for fp8 + fp8_found = False + for f8_type in fp8_types: + if f8_type in content: + fp8_found = True + ttir_files_fp8_found_status[base_name] = fp8_found + + for file, fp8_found in ttir_files_fp8_found_status.items(): + assert fp8_found, f"{fp8_types} not found in {file}" diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py new file mode 100644 index 0000000000..fc5f5d0b1b --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/train.py @@ -0,0 +1,403 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset, random_split +import numpy as np +import pandas as pd +from tqdm import tqdm +import matplotlib.pyplot as plt +from datasets import load_dataset +from flash_attn import flash_attn_qkvpacked_func, flash_attn_qkvpacked_fp8_func, flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_fp8_func + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f"using device: {device}") + +# ------------------------------- +# Model +# ------------------------------- +class FlashAttention(nn.Module): + def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): + super().__init__() + self.use_fp8 = use_fp8 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.causal = causal + self.dropout_p = dropout + + # qkv and output projections + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + b, n, c = x.shape + # project to qkv + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + # reshape for flash attention function + qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) + + # use the appropriate flash attention function + if self.use_fp8: + context = flash_attn_qkvpacked_fp8_func( + qkv_packed, + dropout_p=self.dropout_p, + causal=self.causal + ) + else: + context = flash_attn_qkvpacked_func( + qkv_packed, + dropout_p=self.dropout_p, + causal=self.causal + ) + + # convert back to original shape and type + context = context.reshape(b, n, c) + + # output projection + x = self.proj(context) + + return x + +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) + + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + +class FlashLM(nn.Module): + def __init__( + self, + vocab_size, + dim=256, + depth=6, + num_heads=8, + mlp_ratio=4.0, + causal=True, + dropout=0.1, + max_seq_len=256, + use_fp8=False + ): + super().__init__() + + # embedding layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.position_embedding = nn.Parameter(torch.zeros(1, max_seq_len, 
dim)) + self.dropout = nn.Dropout(dropout) + + # transformer blocks + self.blocks = nn.ModuleList([ + TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) + for _ in range(depth) + ]) + + # lm head: project back to vocabulary dimension for each token + self.norm = nn.LayerNorm(dim) + self.lm_head = nn.Linear(dim, vocab_size) + + def forward(self, x): + b, n = x.shape + + # token + positional embedding + x = self.token_embedding(x) + x = x + self.position_embedding[:, :n, :] + x = self.dropout(x) + + # transformer blocks + for block in self.blocks: + x = block(x) + + # language modeling head + x = self.norm(x) + logits = self.lm_head(x) # shape: (b, n, vocab_size) + return logits + +# ------------------------------- +# Data +# ------------------------------- +class TextDataset(Dataset): + def __init__(self, sequences, max_len=None): + self.sequences = sequences + self.max_len = max_len + + def __len__(self): + return len(self.sequences) + + def __getitem__(self, idx): + seq = self.sequences[idx] + # input: all tokens except the last, target: all tokens except the first + return (torch.tensor(seq[:-1], dtype=torch.long), + torch.tensor(seq[1:], dtype=torch.long)) + +class VarLenTextDataset(Dataset): + def __init__(self, sequences, max_len=256): + self.sequences = sequences + self.max_len = max_len + + def __len__(self): + return len(self.sequences) + + def __getitem__(self, idx): + seq = self.sequences[idx] + # Ensure the sequence doesn't exceed max_len+1 + seq = seq[:self.max_len+1] + # input: all tokens except the last, target: all tokens except the first + return (torch.tensor(seq[:-1], dtype=torch.long), + torch.tensor(seq[1:], dtype=torch.long)) + +def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): + # load the WikiText-2 + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + + # build vocabulary + corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string + tokens = corpus.split() + vocab = sorted(set(tokens)) + word2idx = {word: idx for idx, word in enumerate(vocab)} + token_ids = [word2idx[word] for word in tokens] + + num_workers = 2 + if is_varlen: + # VARIABLE LENGTH: create sequences of different lengths + sequences = [] + for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences + # Decide target length for this sequence + if np.random.random() < ratio_shorter: + # Shorter sequence + target_len = np.random.randint(min_len + 1, max_len + 1) + else: + # Full length sequence + target_len = max_len + 1 + + # Extract sequence up to target length or whatever's available + seq_end = min(i + target_len, len(token_ids)) + seq = token_ids[i:seq_end] + + # Only keep sequences that are long enough + if len(seq) > min_len + 1: # +1 because we need both input and target + sequences.append(seq) + + print(f"Created {len(sequences)} variable-length sequences") + + # Get some statistics + lens = [len(seq) for seq in sequences] + print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") + + # split dataset + num_samples = len(sequences) + num_train = int(0.8 * num_samples) + num_val = num_samples - num_train + + # Use appropriate dataset class based on whether we need variable length + dataset_class = VarLenTextDataset + train_sequences = sequences[:num_train] + val_sequences = sequences[num_train:] + + train_dataset = dataset_class(train_sequences, max_len) + 
val_dataset = dataset_class(val_sequences, max_len) + + + # collate function + def collate_fn(batch): + """ + Collate function that creates a flat representation for variable length flash attention. + """ + # Separate inputs and targets + inputs, targets = zip(*batch) + + # Get sequence lengths + seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) + + # Concatenate inputs and targets into single tensors + flat_inputs = torch.cat(inputs) + flat_targets = torch.cat(targets) + + # Create cumulative sequence lengths tensor + cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) + + # Calculate max sequence length for this batch + max_seqlen = seq_lens.max().item() + + return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen + + # data loaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) + else: + # FIXED LENGTH: create sequences of length max_len+1 + sequences = [] + for i in range(0, len(token_ids) - max_len, max_len): + seq = token_ids[i : i + max_len + 1] + if len(seq) == max_len + 1: + sequences.append(seq) + + # split dataset + num_samples = len(sequences) + num_train = int(0.8 * num_samples) + num_val = num_samples - num_train + train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) + + # data loaders + train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) + val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + vocab_size = len(vocab) + print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") + return train_dataloader, val_dataloader, vocab_size + +# ------------------------------- +# Training +# ------------------------------- +def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): + train_losses = [] + val_losses = [] + for epoch in range(num_epochs): + # Training phase + model.train() + epoch_train_loss = 0.0 + for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): + inputs, targets = inputs.to(device), targets.to(device) + + optimizer.zero_grad() + logits = model(inputs) + loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) + loss.backward() + optimizer.step() + + epoch_train_loss += loss.item() + + epoch_train_loss /= len(train_dataloader) + train_losses.append(epoch_train_loss) + print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") + + # Validation phase + model.eval() + epoch_val_loss = 0.0 + with torch.no_grad(): + for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): + inputs, targets = inputs.to(device), targets.to(device) + logits = model(inputs) + loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) + epoch_val_loss += loss.item() + epoch_val_loss /= len(val_dataloader) + val_losses.append(epoch_val_loss) + print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") + + return train_losses, val_losses + +# ------------------------------- +# Main +# ------------------------------- +def main(): + # hyperparameters + batch_size = 16 + num_epochs = 20 + learning_rate = 3e-4 + max_len = 128 # total length 
including both input and target tokens + is_varlen = False + causal=True + dropout=0.1 + + # prep data + print("Preparing Dataset") + train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) + + # create language models + print("Creating Models") + model_normal = FlashLM( + vocab_size=vocab_size, + dim=256, + depth=3, + num_heads=8, + causal=causal, + dropout=dropout, + max_seq_len=max_len, + ).to(device) + + model_fp8 = FlashLM( + vocab_size=vocab_size, + dim=256, + depth=3, + num_heads=8, + causal=causal, + dropout=dropout, + max_seq_len=max_len, + use_fp8=True + ).to(device) + + # Train Normal model + print("Starting training for Normal model...") + optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) + criterion = nn.CrossEntropyLoss() + normal_train_losses, normal_val_losses = train_lm( + model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs + ) + torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') + print("Normal model training complete and saved.") + + # Train FP8 model + print("Starting training for FP8 model...") + optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) + fp8_train_losses, fp8_val_losses = train_lm( + model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs + ) + torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') + print("FP8 model training complete and saved.") + + # save losses to csv + epochs = range(1, num_epochs+1) + loss_data = { + "Epoch": epochs, + "Normal_Training_Loss": normal_train_losses, + "Normal_Validation_Loss": normal_val_losses, + "FP8_Training_Loss": fp8_train_losses, + "FP8_Validation_Loss": fp8_val_losses, + } + df_losses = pd.DataFrame(loss_data) + df_losses.to_csv("losses.csv", index=False) + print("Loss data saved to losses.csv") + + # plot Training Loss + plt.figure(figsize=(10, 6)) + plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') + plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') + plt.xlabel("Epoch") + plt.ylabel("Training Loss") + plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") + plt.legend() + plt.grid(True) + plt.savefig("training_loss.png") # Saves the training loss plot to disk + plt.show() + + # Plot Validation Loss + plt.figure(figsize=(10, 6)) + plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') + plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') + plt.xlabel("Epoch") + plt.ylabel("Validation Loss") + plt.title("Validation Loss Comparison: Normal vs FP8 Flash Attention") + plt.legend() + plt.grid(True) + plt.savefig("validation_loss.png") # Saves the validation loss plot to disk + plt.show() + + +if __name__ == "__main__": + main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 530455063e..0300e3902a 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -1,32 +1,58 @@ - +import csv +import math import torch import os +import random +import functools import triton +import triton.language as tl +from typing import Literal, Optional, Union +# ------------------------------- +# Gloabl Variables +# ------------------------------- AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes') DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') +USE_REF = 
os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes') PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') - +USE_SINGLE_BWD_KERNEL = os.environ.get('USE_SINGLE_BWD_KERNEL', '0').lower() in ('1', 'true', 'yes') +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +USE_TRITON_INTERPRET = os.environ.get('TRITON_INTERPRET', '0').lower() in ('1', 'true', 'yes') +DEBUG_TRITON = os.environ.get('DEBUG_TRITON', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET +DEBUG_TRITON_DETAIL = os.environ.get('DEBUG_TRITON_DETAIL', '0').lower() in ('1', 'true', 'yes') and USE_TRITON_INTERPRET +if USE_TRITON_ROCM: # TODO remove this + random.seed(42) +DROPOUT_USE_PYTORCH = False +DROPOUT_DUMP = False + + +# ------------------------------- +# Metadata +# ------------------------------- class MetaData(): - cu_seqlens_q = None - cu_seqlens_k = None - max_seqlens_q = 0 - max_seqlens_k = 0 - bias = None - alibi_slopes = None - causal = False + cu_seqlens_q: Optional[torch.Tensor] = None + cu_seqlens_k: Optional[torch.Tensor] = None + max_seqlens_q: int = 0 + max_seqlens_k: int = 0 + bias: Optional[torch.Tensor] = None + alibi_slopes: Optional[torch.Tensor] = None + causal: bool = False num_contexts = 0 - varlen = False - layout = None - cache_seqlens = None + varlen: bool = False + layout: Optional[Literal["bshd", "bhsd", "thd"]] = None + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None cache_batch_idx = None - new_kv = False - seqlen_new = None - k_new = None - v_new = None - dropout_p, return_scores= 0.0, False + packing: Optional[bool] = None + return_scores: bool = False + dropout_p: float = 0.0 + philox_seed: Optional[int] = None + philox_offset : Optional[int]= None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing. # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW. - use_exp2 = False + use_exp2: bool = False + rotary_sin: Optional[torch.Tensor] = None + rotary_cos: Optional[torch.Tensor] = None + rotary_interleaved: bool = False + rotary_conjunction: bool = False def __repr__(self) -> str: @@ -44,10 +70,6 @@ def __repr__(self) -> str: f" layout={self.layout},\n" f" cache_seqlens={self.cache_seqlens},\n" f" cache_batch_idx={self.cache_batch_idx},\n" - f" new_kv={self.new_kv},\n" - f" seqlen_new={self.seqlen_new},\n" - f" k_new={self.k_new},\n" - f" v_new={self.v_new},\n" f" dropout_p={self.dropout_p},\n" f" return_scores={self.return_scores}\n" f")") @@ -55,18 +77,17 @@ def __repr__(self) -> str: def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale - def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k): self.varlen = True self.layout = 'thd' self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k + self.max_seqlens_q = max_seqlen_q + self.max_seqlens_k = max_seqlen_k + # Without "varlen", there should still be one sequence. 
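# For illustration: three sequences of lengths 3, 5 and 2 packed into a 'thd' tensor
# give cu_seqlens = [0, 3, 8, 10] and max_seqlen = 5, so a call (with a hypothetical
# sm_scale) looks like
#     meta = MetaData(sm_scale=0.125)
#     meta.set_varlen_params(torch.tensor([0, 3, 8, 10]), torch.tensor([0, 3, 8, 10]), 5, 5)
# Even a single packed sequence yields a cu_seqlens of length 2, hence the asserts below.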
assert len(cu_seqlens_q) >= 2 assert len(cu_seqlens_q) == len(cu_seqlens_k) - self.num_contexts = len(cu_seqlens_q) - 1 - for i in range(0, self.num_contexts): - self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) - self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): assert bias.is_cuda @@ -82,17 +103,25 @@ def need_alibi(self, alibi_slopes, batch, nheads): assert alibi_slopes.shape[1] == nheads self.alibi_slopes = alibi_slopes - def need_causal(self): - self.causal = True + def need_causal(self, causal): + self.causal = causal + + def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False): + self.rotary_sin = sin + self.rotary_cos = cos + self.rotary_interleaved = rotary_interleaved + self.rotary_conjunction = rotary_conjunction - def need_dropout(self, dropout_p, return_scores): - self.dropout_p = dropout_p - self.return_scores = return_scores + def need_dropout(self, dropout_p, return_scores = True): + if dropout_p > 0.0: + self.dropout_p = dropout_p + self.return_scores = return_scores + self.philox_seed, self.philox_offset = 0x1BF58, 0x1D4B49 def check_args(self, q, k, v, o): assert q.dim() == k.dim() and q.dim() == v.dim() - batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) + batch, nheads_q, nheads_k, head_size, _, _ = get_shapes_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k) if self.varlen: assert q.dim() == 3 assert self.cu_seqlens_q is not None @@ -100,8 +129,6 @@ def check_args(self, q, k, v, o): assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) # TODO: Remove once bias is supported with varlen assert self.bias is None - # TODO:Remove once dropout is supported with varlen - assert self.dropout_p == 0.0 # assert not self.return_scores else: assert q.dim() == 4 @@ -111,131 +138,545 @@ def check_args(self, q, k, v, o): assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen -def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, device="cuda", DEBUG_INPUT=False): - torch.manual_seed(20) +# ------------------------------- +# Input Helper +# ------------------------------- +def random_seqlens_composition(SEQ_LEN, BATCH): + # generate a random composition of N into Z positive parts. 
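# For example, with SEQ_LEN=10 and BATCH=3: sampling two distinct cut points from
# {1, ..., 9}, say {4, 7}, gives breakpoints [0, 4, 7, 10] and hence seqlens
# [4, 3, 3], i.e. BATCH positive parts that sum to SEQ_LEN.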
+ idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 + idx, _ = torch.sort(idx) + breakpoints = torch.cat([ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([SEQ_LEN], dtype=torch.long), + ]) + seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) + return seqlens + +def generate_varlen_tensor( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float32, + DEBUG_INPUT: bool = False +): + if DEBUG: + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), + total_seqlen // batch_size, + dtype=torch.int32, + device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) - # Initialize q, k, v - if layout == 'bhsd': - q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) - k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == 'bshd': - q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) - k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + # create cumulative sequence lengths + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)]).to(torch.int32).to(device=device) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen tensor + if DEBUG_INPUT: + x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + length = end - start + + x[start:end, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1) + .expand(length, num_heads, head_size) + ) else: - assert False, f'Got unsupported tensor layout: {layout}' + x = torch.randn((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + x.requires_grad_() + return x, cu_seqlens, max_seqlen, descale_x + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + +def generate_bshd_tensor(BATCH, SEQ_LEN, NUM_HEADS, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) if DEBUG_INPUT: - if layout == "bhsd": - q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, 1, N_CTX_Q, 1).expand(*q_tensor_shape).contiguous().requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - v = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, 1, N_CTX_K, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - elif layout == "bshd": - q = torch.arange(N_CTX_Q, dtype=dtype, device=device).view(1, N_CTX_Q, 1, 1).expand(*q_tensor_shape).contiguous().requires_grad_() - k = torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() - v = 
torch.arange(N_CTX_K, dtype=dtype, device=device).view(1, N_CTX_K, 1, 1).expand(*k_tensor_shape).contiguous().requires_grad_() + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, SEQ_LEN, 1, 1).expand(*tensor_shape).contiguous() + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") + x.requires_grad_() + return x, descale_x else: - q = torch.randn(q_tensor_shape, dtype=dtype, device=device, requires_grad=True) - k = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) - v = torch.randn(k_tensor_shape, dtype=dtype, device=device, requires_grad=True) + x.requires_grad_() + return x + +def generate_bhsd_tensor(BATCH, NUM_HEADS, SEQ_LEN, D_HEAD, dtype, device="cuda", DEBUG_INPUT=False): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + # gen tensor + tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) if DEBUG_INPUT: - sm_scale = 1 + x = torch.arange(SEQ_LEN, dtype=dtype, device=device).view(1, 1, SEQ_LEN, 1).expand(*tensor_shape).contiguous() else: - sm_scale = D_HEAD**-0.5 - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.max_seqlens_q = N_CTX_Q - input_metadata.max_seqlens_k = N_CTX_K - input_metadata.layout = layout - return q, k, v, input_metadata - + x = torch.randn(tensor_shape, dtype=dtype, device=device) + -def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device="cuda", equal_seqlens=False, DEBUG_INPUT=False): + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8(x, og_fp8_dtype, "bhsd") # FIXME: I don't the casting fn supports this atm + x.requires_grad_() + return x, descale_x + else: + x.requires_grad_() + return x + +def input_helper( + BATCH: int, + HQ: int, + HK: int, + N_CTX_Q: int, + N_CTX_K: int, + D_HEAD: int, + CAUSAL: bool, + DROPOUT_P: float, + dtype: torch.dtype, + layout: Literal["bshd", "bhsd", "thd"], + packing: Optional[Literal["kv", "qkv"]] = None, + device: Literal["cpu", "cuda"] = "cuda", + DEBUG_INPUT: bool = False, +): torch.manual_seed(20) - - # Random or equal sequence lengths based on 'equal_seqlens' flag - if not equal_seqlens: - max_seqlens_q = N_CTX_Q // Z - max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) - seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + is_fp8_dtype = is_dtype_fp8(dtype) + + if layout == "thd": + # set params + TOTAL_SEQLENS_Q = BATCH * N_CTX_Q + TOTAL_SEQLENS_K = BATCH * N_CTX_K + equal_seqlens=False + + # gen tensors + # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen + if is_fp8_dtype: + q, cu_seqlens_q, max_seqlen_q, descale_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k, descale_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _ , descale_v = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do, _, _ , descale_do = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens) + else: + q, cu_seqlens_q, max_seqlen_q = 
generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + + # setup metadata + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + metadata = MetaData(sm_scale=sm_scale) + metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k) + metadata.need_causal(CAUSAL) + metadata.need_dropout(DROPOUT_P) + elif layout == 'bshd' or layout == "bhsd": + # gen tensors + if layout == "bshd": + if is_fp8_dtype: + q, descale_q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + elif layout == "bhsd": + if is_fp8_dtype: + q, descale_q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k, descale_k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v, descale_v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do, descale_do = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device) + else: + q = generate_bhsd_tensor(BATCH, HQ, N_CTX_Q, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + k = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + v = generate_bhsd_tensor(BATCH, HK, N_CTX_K, D_HEAD, dtype=dtype, device=device, DEBUG_INPUT=DEBUG_INPUT) + do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q) + + # setup metadata + if DEBUG_INPUT: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 + metadata = MetaData(sm_scale=sm_scale) + metadata.max_seqlens_q = N_CTX_Q + metadata.max_seqlens_k = N_CTX_K + metadata.layout = layout + metadata.need_causal(CAUSAL) + metadata.need_dropout(DROPOUT_P) else: - seqlens_q = torch.full((Z,), N_CTX_Q // Z, dtype=torch.int32) - seqlens_k = torch.full((Z,), N_CTX_K // Z, dtype=torch.int32) + raise ValueError(f"Unknown layout: {layout}") - # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0)]) - cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0)]) - cu_seqlens_q = cu_seqlens_q.to(device=device).to(torch.int32) - cu_seqlens_k = cu_seqlens_k.to(device=device).to(torch.int32) + # deal with packing + if 
packing is None: + if is_fp8_dtype: + return (q, descale_q), (k, descale_k), (v, descale_v), (do, descale_do), metadata + else: + return q, k, v, do, metadata + elif packing == "kv": + # pack k and v + if layout in ["bhsd", "thd"]: + kv = torch.stack([k, v], dim=1) + elif layout == "bshd": + kv = torch.stack([k, v], dim=2) + else: + raise ValueError(f"Unknown layout: {layout}") - # Total lengths - total_q = cu_seqlens_q[-1].item() - total_k = cu_seqlens_k[-1].item() + if is_fp8_dtype: + raise ValueError("FP8 not supported kv packing yet") + else: + return q, kv, do, metadata + elif packing == "qkv": + # qkv packing - requires same sequence length for q and k + assert N_CTX_Q == N_CTX_K, "For QKV packing, Q and K must have same sequence length" + assert HQ == HK, "For QKV packing, Q and K must have same number of heads" + + # pack q, k, and v + if layout in ["bhsd", "thd"]: + qkv = torch.stack([q, k, v], dim=1) + elif layout == "bshd": + qkv = torch.stack([q, k, v], dim=2) + else: + raise ValueError(f"Unknown layout: {layout}") - if DEBUG_INPUT: - # Initialize q, k, v with deterministic values - q = torch.arange(total_q, dtype=dtype, device=device).view(total_q, 1, 1) - q = q.expand(total_q, HQ, D_HEAD).contiguous().requires_grad_() - k = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) - k = k.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() - v = torch.arange(total_k, dtype=dtype, device=device).view(total_k, 1, 1) - v = v.expand(total_k, HK, D_HEAD).contiguous().requires_grad_() - sm_scale = 1 + if is_fp8_dtype: + raise ValueError("FP8 not supported qkv packing yet") + else: + return qkv, do, metadata else: - # Initialize q, k, v with random values - q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device=device).requires_grad_() - k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() - v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device=device).requires_grad_() - sm_scale = D_HEAD ** -0.5 - - input_metadata = MetaData(sm_scale=sm_scale) - input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) - return q, k, v, input_metadata - - -def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + assert False, f"Unsupported packing mode: {packing}" + +# ------------------------------- +# Alibi +# ------------------------------- +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + +# ------------------------------- +# FP8 +# ------------------------------- +def is_dtype_fp8(dtype): + if dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}: + if arch_supports_fp8(): + return True + else: + raise RuntimeError("This device doesnot support fp8") + else: + return False + +def is_fp8(x): + return is_dtype_fp8(x.dtype) + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + # compute fp8 scaling and descaling factor for a block + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + +@triton.jit +def _cast_varlen_to_fp8_kernel_2d( + X, X_fp8, Descale, + cu_seqlens, H, MAX_SEQLEN, + stride_batch, stride_seq, stride_head, stride_dim, + stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, + stride_desc_batch, stride_desc_head, + FP8_CLAMP_VAL, + FP8_MAX, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr + ): + # Process one (batch, head) pair per kernel + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + # Get sequence bounds for this batch + if IS_VARLEN: + seq_start = tl.load(cu_seqlens + b_id) + seq_end = tl.load(cu_seqlens + b_id + 1) + seqlen = seq_end - seq_start + else: + seq_start = 0 + seqlen = MAX_SEQLEN + + # initialize max value tracker + x_max_val = 0.0 + + # STEP 1: Find max absolute value across the entire sequence + num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) + for blk_idx in range(0, num_of_blocks): + # print("blk_idx:", blk_idx) + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + # Load block + adj_x = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) + # print("x_block:", x_block) + + # Find max absolute value in this block + block_max = tl.max(tl.abs(x_block)) + # print("block_max:", block_max) + + # Update overall max + x_max_val = tl.maximum(x_max_val, block_max) + # print("x_max_val:", x_max_val) + + # clamp to avoid division by zero issues + x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) + + # compute scale and descale factors for the entire sequence + scale = FP8_MAX / x_max_val + descale = x_max_val / FP8_MAX + + # store descale factor for this (batch, head) pair + desc_ptr = Descale + b_id * stride_desc_batch + h_id# * stride_desc_head + tl.store(desc_ptr, descale) + + # STEP 2: Apply scaling to the entire sequence and convert to FP8 + for blk_idx in range(0, num_of_blocks): + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & 
mask_dim + + # Load block - Using the fixed addressing + addr = b_id * stride_batch + h_id * stride_head + seq_start * stride_seq + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim + x_block = tl.load(X + addr, mask=mask_seq, other=0.0) + + # Apply scale and convert to FP8 + x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) + + # Store results + addr_out = b_id * stride_out_batch + h_id * stride_out_head + seq_start * stride_out_seq + offs_seq[:, None] * stride_out_seq + offs_dim[None, :] * stride_out_dim + tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + layout: Literal["bshd", "thd"], + clamp_val: float = 1e-9, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None +) -> tuple[torch.Tensor, torch.Tensor]: + if False: + print() + print("cast_to_fp8") + print("x:", x, x.shape) + print("fp8_dtype:", fp8_dtype) + print("cu_seqlens:", cu_seqlens) + print("max_seqlen:", max_seqlen) + print("clamp_val:", clamp_val) + + # check types are valid + assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + + # extract dimensions + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout(x, layout, cu_seqlens, max_seqlen) + is_varlen = layout == "thd" + fp8_max = torch.finfo(fp8_dtype).max + if False: + print("batch:", batch) + print("max_seqlen_final:", max_seqlen_final) + print("num_heads:", num_heads) + print("head_dim:", head_dim) + + # get closest power of 2 for head_dim + padded_head_dim = 1 << (head_dim - 1).bit_length() + padded_head_dim = max(padded_head_dim, 32) + + # kernel params + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros((batch, num_heads), device=x.device, dtype=torch.float32) + BLOCK_SIZE = 128 + + # calculate strides + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout(x, layout) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = get_stride_from_layout(x_fp8, layout) + stride_desc_batch, stride_desc_head = descale_factors.stride() + + if False: + print("stride_batch", stride_batch) + print("stride_head", stride_head) + print("stride_seq", stride_seq) + print("stride_dim", stride_dim) + print("stride_out_batch", stride_out_batch) + print("stride_out_head", stride_out_head) + print("stride_out_seq", stride_out_seq) + print("stride_out_dim", stride_out_dim) + print("stride_desc_batch", stride_desc_batch) + print("stride_desc_head", stride_desc_head) + + grid = (batch, num_heads) + _cast_varlen_to_fp8_kernel_2d[grid]( + x, x_fp8, descale_factors, + cu_seqlens, num_heads, max_seqlen_final, + stride_batch, stride_seq, stride_head, stride_dim, + stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, + stride_desc_batch, stride_desc_head, + clamp_val, fp8_max, + BLOCK_SIZE=BLOCK_SIZE, + HEAD_DIM=padded_head_dim, + ACTUAL_HEAD_DIM=head_dim, + IS_VARLEN=is_varlen + ) + + if False: + print("x_fp8:", x_fp8, x_fp8.shape) + print("descale_factors:", descale_factors, descale_factors.shape) + return x_fp8, descale_factors + +# ------------------------------- +# Misc +# ------------------------------- +def get_shape_from_layout( + x: torch.Tensor, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[int, int, int, int]: if layout == 'bhsd': - batch_q, nheads_q, max_seqlen_q, head_size_q = q.shape - batch_k, 
nheads_k, max_seqlen_k, head_size_k = k.shape + batch, num_heads, max_seqlen_final, head_dim = x.shape elif layout == 'bshd': - batch_q, max_seqlen_q, nheads_q, head_size_q = q.shape - batch_k, max_seqlen_k, nheads_k, head_size_k = k.shape + batch, max_seqlen_final, num_heads, head_dim = x.shape elif layout == 'thd': - batch_q, max_seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2] - batch_k, max_seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2] + total_seqlen, num_heads, head_dim = x.shape + if cu_seqlens is None: + raise ValueError("cu_seqlens must be provided for varlen (thd) layout") + if max_seqlen is None: + raise ValueError("max_seqlen must be provided for varlen (thd) layout") + + batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim else: assert False, "Got unsupported layout." + + return batch, max_seqlen_final, num_heads, head_dim + + +def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None): + batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q) + batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k) # assert assert batch_q == batch_k assert head_size_q == head_size_k - return batch_q, nheads_q, nheads_k, head_size_q, max_seqlen_q, max_seqlen_k + return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k -def get_strides_from_layout(q, k, v, o, layout): +def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]): if layout == 'thd': - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + strides = (0, x.stride(1), x.stride(0), x.stride(2)) elif layout == 'bhsd': - q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) - k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) - v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) - o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3)) elif layout == 'bshd': - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) else: assert False, 'Got unsupported layout.' 
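# Host-side sketch of the per-(batch, head) fp8 scaling scheme that cast_to_fp8 and
# _cast_varlen_to_fp8_kernel_2d implement above: take one amax over the whole sequence for
# each (batch, head), scale by fp8_max / amax, and keep descale = amax / fp8_max in fp32 so
# later kernels can undo the scaling. Plain PyTorch on a 'bshd' tensor, for illustration
# only; the fp8 dtype below is an assumption, not what the kernel hard-codes.
import torch

def cast_to_fp8_ref(x_bshd, fp8_dtype=torch.float8_e4m3fn, clamp_val=1e-9):
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x_bshd.abs().amax(dim=(1, 3), keepdim=True).clamp(min=clamp_val)  # (B, 1, H, 1)
    x_fp8 = (x_bshd * (fp8_max / amax)).to(fp8_dtype)
    descale = (amax / fp8_max).reshape(x_bshd.shape[0], x_bshd.shape[2])     # (B, H)
    return x_fp8, descale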
+ return strides + +def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None): + return get_shape_from_layout(x, layout, cu_seqlens, max_seqlen), get_stride_from_layout(x, layout) + +def get_strides_from_layout(q, k, v, o, layout): + q_strides = get_stride_from_layout(q, layout) + k_strides = get_stride_from_layout(k, layout) + v_strides = get_stride_from_layout(v, layout) + o_strides = get_stride_from_layout(o, layout) return q_strides, k_strides, v_strides, o_strides def get_padded_headsize(size): @@ -246,29 +687,90 @@ def get_padded_headsize(size): padded_d_model = max(padded_d_model, 16) return padded_d_model - -def _strides(x: torch.Tensor, *stride_names: str): - if x is None: - return {f"stride_{s}": 0 for i, s in enumerate(stride_names)} - - assert x.ndim == len(stride_names) - return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} - -def get_input_shapes(): - cases = [(max(1, 2**(16 - i)), 1, 2**i, 16, 1, 128) - for i in range(8, 18)] + [(max(1, 2**(16 - i)), 1, 2**i, 16, 2, 128) for i in range(8, 18)] - return cases - +def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + +# ------------------------------- +# Dropouts +# ------------------------------- +def create_dropout_mask(dropout_p, shape, seed): + device = "cuda" + rand_vals = torch.rand(shape, generator=torch.Generator(device=device).manual_seed(seed), device=device, dtype=torch.float32) + return rand_vals > dropout_p + +def create_dropout_mask_varlen(dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed): + device = "cuda" + qlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) + klens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) + max_qlen = qlens.max() + max_klen = klens.max() + dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) + for b in range(batch): + qlen = qlens[b] + klen = klens[b] + rand_vals = torch.rand((nheads_q, qlen, klen), generator=torch.Generator(device=device).manual_seed(philox_seed), device=device, dtype=torch.float32) + submask = rand_vals > dropout_p + dropout_mask[b, :, :qlen, :klen] = submask + + return dropout_mask + +def write_dropout_mask(x, tensor_name = "tensor"): + batch, head, seqlen_m, seqlen_n = x.shape + x = x.tolist() + + with open(f'{tensor_name}.csv', 'w') as f: + writer = csv.writer(f) + for b in range(batch): + for h in range(head): + dropout_mask = x[b][h] + if True: + BLOCK_M = 64 + BLOCK_N = 64 + + # Calculate number of blocks in each dimension + m_blocks = math.ceil(seqlen_m / BLOCK_M) + n_blocks = math.ceil(seqlen_n / BLOCK_N) + + # Process each block + for m_block in range(m_blocks): + # Calculate row range for current block + row_start = m_block * BLOCK_M + row_end = min(row_start + BLOCK_M, seqlen_m) + + for n_block in range(n_blocks): + # Calculate column range for current block + col_start = n_block * BLOCK_N + col_end = min(col_start + BLOCK_N, seqlen_n) + + # Extract and write the current block + for row_idx in range(row_start, row_end): + row_data = dropout_mask[row_idx][col_start:col_end] + writer.writerow(row_data) + else: + 
writer.writerows(dropout_mask) + +# ------------------------------- +# Runtime info +# ------------------------------- +@functools.cache def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" +@functools.cache +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch +@functools.cache def is_cdna(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', - 'gfx90a', 'gfx908') - + return is_hip() and get_arch() in ('gfx908', 'gfx90a', 'gfx940', 'gfx941', 'gfx942', 'gfx950') +@functools.cache def is_rdna(): - return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", - "gfx1102", "gfx1200", "gfx1201") + return is_hip() and get_arch() in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") +@functools.cache +def arch_supports_fp8(): + return is_hip() and get_arch() in ('gfx942') diff --git a/flash_attn/fused_softmax.py b/flash_attn/fused_softmax.py deleted file mode 100644 index 382f94f092..0000000000 --- a/flash_attn/fused_softmax.py +++ /dev/null @@ -1,201 +0,0 @@ -# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py -# for benchmarking. -# We added support for seqlen=2k and seqlen=4k - -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType -from fused_softmax_lib import ( - scaled_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_get_batch_per_block, - scaled_upper_triang_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, -) - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. 
- """ - - @staticmethod - def forward(ctx, inputs, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. 
- """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 8192 # sk must be 16 ~ 8192 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 8192: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py index 6d021f8391..72e43337ad 100644 --- a/flash_attn/layers/rotary.py +++ b/flash_attn/layers/rotary.py @@ -1,9 +1,12 @@ -# Copyright (c) 2023, Tri Dao. 
+# Copyright (c) 2025, Tri Dao import math +from functools import partial from typing import Optional, Tuple, Union import torch +from torch import Tensor + from einops import rearrange, repeat from flash_attn.ops.triton.rotary import apply_rotary @@ -41,8 +44,8 @@ def forward( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, + seqlen_offsets: Union[int, Tensor] = 0, + cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): out = apply_rotary( @@ -73,10 +76,6 @@ def backward(ctx, do): cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: cos, sin, cu_seqlens = ctx.saved_tensors - # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with - # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. - if not ctx.interleaved and not ctx.inplace: - do = do.clone() dx = apply_rotary( do, cos, @@ -97,8 +96,8 @@ def apply_rotary_emb( sin, interleaved=False, inplace=False, - seqlen_offsets: Union[int, torch.Tensor] = 0, - cu_seqlens: Optional[torch.Tensor] = None, + seqlen_offsets: Union[int, Tensor] = 0, + cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): """ @@ -128,6 +127,70 @@ def apply_rotary_emb( apply_rotary_emb_func = apply_rotary_emb +def _apply_rotary_emb_qkv( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + inplace=False, + conjugate=False, + seqlen_offsets: Union[int, Tensor] = 0, + num_heads_q: Optional[int] = None, +): + apply_rotary_fn = partial( + apply_rotary, + interleaved=interleaved, + inplace=inplace, + conjugate=conjugate, + seqlen_offsets=seqlen_offsets + ) + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + if qkv.dim() == 5: + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + qk = apply_rotary_fn(qk, cos, sin) + else: + assert qkv.dim() == 4 + assert num_heads_q is not None + num_heads_k = (qkv.shape[2] - num_heads_q) // 2 + assert qkv.shape[2] == num_heads_q + 2 * num_heads_k + qk = qkv[:, :, :num_heads_q + num_heads_k] + qk = apply_rotary_fn(qk, cos, sin) + if not inplace: + if qkv.dim() == 5: + qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) + else: + qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + if qkv.dim() == 5: + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + q, k = qkv[:, :, 0], qkv[:, :, 1] + else: + assert qkv.dim() == 4 + assert num_heads_q is not None + num_heads_k = (qkv.shape[2] - num_heads_q) // 2 + assert qkv.shape[2] == num_heads_q + 2 * num_heads_k + q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] + q = apply_rotary_fn(q, cos, sin) + k = apply_rotary_fn(k, cos_k, sin_k) + if not inplace: + if qkv.dim() == 5: + qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) + else: + qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) + return qkv + + class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod def forward( @@ -139,40 +202,13 @@ def forward( sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, - 
num_heads_q: Union[int] = None, + num_heads_q: Optional[int] = None, ): - if cos_k is None and sin_k is None and qkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need qkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if qkv.dim() == 5: - batch, seqlen, three, nheads, headdim = qkv.shape - assert three == 3 - # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") - qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - qk = qkv[:, :, :num_heads_q + num_heads_k] - apply_rotary( - qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if qkv.dim() == 5: - q, k = qkv[:, :, 0], qkv[:, :, 1] - else: - assert qkv.dim() == 4 - assert num_heads_q is not None - num_heads_k = (qkv.shape[2] - num_heads_q) // 2 - assert qkv.shape[2] == num_heads_q + 2 * num_heads_k - q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] - apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) - apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) - ctx.save_for_backward(cos, sin, cos_k, sin_k) + # apply_rotary_emb_qkv_inplace( + qkv = _apply_rotary_emb_qkv( + qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, + seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, + ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cos_k, sin_k) ctx.seqlen_offsets = seqlen_offsets @@ -190,57 +226,10 @@ def backward(ctx, dqkv): cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors else: cos, sin, cos_k, sin_k = ctx.saved_tensors - if cos_k is None and sin_k is None and dqkv.is_contiguous(): - # Call 1 kernel instead of 2 kernels - # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) - # dimensions, we get the same tensor - if dqkv.dim() == 5: - dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") - else: - assert dqkv.dim() == 4 - assert ctx.num_heads_q is not None - num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2 - assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k - dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k] - apply_rotary( - dqk, - cos, - sin, - seqlen_offsets=seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - else: - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - if dqkv.dim() == 5: - dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] - else: - assert dqkv.dim() == 4 - assert ctx.num_heads_q is not None - num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2 - assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k - dq = dqkv[:, :, : ctx.num_heads_q] - dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k] - apply_rotary( - dq, - cos, - sin, - seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) - apply_rotary( - dk, - cos_k, - sin_k, - seqlen_offsets, - interleaved=ctx.interleaved, - inplace=True, - conjugate=True, - ) + dqkv = _apply_rotary_emb_qkv( + dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, + seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, + ) return dqkv, None, None, None, None, None, None, None @@ -276,6 +265,7 @@ 
def apply_rotary_emb_qkv_( class ApplyRotaryEmbKV_(torch.autograd.Function): + @staticmethod def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): batch, seqlen, two, nheads, headdim = kv.shape @@ -362,27 +352,15 @@ def __init__( base=10000.0, interleaved=False, scale_base=None, - pos_idx_in_fp32=True, device=None, ): """ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). - pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, - otherwise they might be in lower precision. - This option was added because previously (before 2023-07-02), when we construct - the position indices, we use the dtype of self.inv_freq. In most cases this would - be fp32, but if the model is trained in pure bf16 (not mixed precision), then - self.inv_freq would be bf16, and the position indices are also in bf16. - Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the - embeddings for some positions will coincide. - To maintain compatibility with models previously trained in pure bf16, - we add this option. """ super().__init__() self.dim = dim self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 # Generate and save the inverse frequency buffer (non trainable) inv_freq = self._compute_inv_freq(device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -421,21 +399,16 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): self._seq_len_cached = seqlen # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. - # We want to recompute self.inv_freq if it was not loaded in fp32 - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
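# Why the position indices and freqs are kept in fp32 here: bf16 cannot represent every
# integer once positions exceed 256, so distinct positions would collapse onto the same
# rotary angle. A self-contained check (dim=64 and base=10000 are illustrative values):
import torch

t_fp32 = torch.arange(4096, dtype=torch.float32)
t_bf16 = t_fp32.to(torch.bfloat16).to(torch.float32)
collisions = (t_fp32 != t_bf16).sum().item()       # > 0: many positions are not exactly representable
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
freqs = torch.outer(t_fp32, inv_freq)              # fp32 outer product, as in _update_cos_sin_cache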
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) inv_freq = self.inv_freq - # Don't do einsum, it converts fp32 to fp16 under AMP + # Don't do einsum, it converts fp32 to bf16 under AMP # freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.outer(t, inv_freq) if self.scale is None: @@ -479,26 +452,16 @@ def forward( elif isinstance(seqlen_offset, int): self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) if kv is None: - if self.scale is None: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached, - self._sin_cached, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - num_heads_q=num_heads_q, - ) + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached if self.scale is not None else None, + self._sin_k_cached if self.scale is not None else None, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + num_heads_q=num_heads_q, + ) else: q = qkv q = apply_rotary_emb_func( @@ -509,20 +472,11 @@ def forward( inplace=True, seqlen_offsets=seqlen_offset, ) - if self.scale is None: - kv = apply_rotary_emb_kv_( - kv, - self._cos_cached, - self._sin_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) - else: - kv = apply_rotary_emb_kv_( - kv, - self._cos_k_cached, - self._sin_k_cached, - interleaved=self.interleaved, - seqlen_offsets=seqlen_offset, - ) + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached if self.scale is None else self._cos_k_cached, + self._sin_cached if self.scale is None else self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) return q, kv diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 77640c2b23..b2a7f22d24 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -23,9 +23,9 @@ flash_attn_with_kvcache = None try: - from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear + from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear except ImportError: - FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None + ColumnParallelLinear, RowParallelLinear = None, None try: from flash_attn.layers.rotary import RotaryEmbedding @@ -341,13 +341,6 @@ def forward(self, q, kv, causal=None, key_padding_mask=None): return output -class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input), input - - def _update_kv_cache(kv, inference_params, layer_idx): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" # Pre-allocate memory for key-values for inference. 
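# The body of _update_kv_cache is outside this hunk; the sketch below shows the usual
# pre-allocated cache-write pattern it refers to. `kv_cache` and `seqlen_offset` are
# illustrative assumptions, not the module's exact API.
import torch

def update_kv_cache_sketch(kv_cache, kv, seqlen_offset):
    # kv_cache: (batch_size, max_seqlen, 2, nheads_k, head_dim), allocated once per layer
    # kv:       (batch_size, seqlen_new, 2, nheads_k, head_dim), this step's new keys/values
    batch, seqlen_new = kv.shape[0], kv.shape[1]
    kv_cache[:batch, seqlen_offset:seqlen_offset + seqlen_new] = kv
    return kv_cache[:batch, :seqlen_offset + seqlen_new]   # every position cached so far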
@@ -452,13 +445,6 @@ def __init__( device=device, ) - if fused_bias_fc and FusedDense is None: - raise ImportError("fused_dense is not installed") - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - linear_resid_cls = ( - LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) - ) - wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn @@ -470,10 +456,10 @@ def __init__( else CrossAttention ) if not self.cross_attn: - self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) else: - self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) - self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) if self.dwconv: if self.num_heads_kv == self.num_heads: self.dwconv_qkv = nn.Conv1d( @@ -492,7 +478,7 @@ def __init__( self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) - self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype @@ -646,10 +632,7 @@ def forward( batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) if self.dwconv: qkv = rearrange( self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" @@ -680,21 +663,11 @@ def forward( ) else: if self.cross_attn: - if not self.return_residual: - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) - kv = self.Wkv(x_kv if x_kv is not None else x) - else: - if x_kv is not None: - kv, x_kv = self.Wkv(x_kv) - else: - kv, x = self.Wkv(x) - q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) else: assert self.num_heads_kv != self.num_heads - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) + qkv = self.Wqkv(x) q = qkv[..., : self.num_heads * self.head_dim] kv = qkv[..., self.num_heads * self.head_dim :] q = rearrange(q, "... (h d) -> ... 
h d", d=self.head_dim) diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index 1e45b8e609..6b4033d134 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -11,9 +11,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd from flash_attn.utils.distributed import ( all_gather_raw, diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 7b0315b979..1b5a415b73 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -166,6 +166,7 @@ def forward( if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: labels = F.pad(labels, (0, 1))[..., :-1] assert labels.data_ptr() % 16 == 0 + assert logit_scale > 0.0 n_rows, n_cols = logits.shape assert labels.shape == (n_rows,) world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) diff --git a/flash_attn/ops/triton/layer_norm.py b/flash_attn/ops/triton/layer_norm.py index addffe1f18..192cee474b 100644 --- a/flash_attn/ops/triton/layer_norm.py +++ b/flash_attn/ops/triton/layer_norm.py @@ -7,14 +7,39 @@ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. import math +from typing import Optional, List import torch import torch.nn.functional as F -from torch.cuda.amp import custom_fwd, custom_bwd +from torch import Tensor import triton import triton.language as tl +from flash_attn.utils.torch import custom_fwd, custom_bwd +from flash_attn.utils.library import triton_op + + +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] + if warp_count * warp_size <= max_threads_per_block] + # return [triton.Config({}, num_warps=8)] + def layer_norm_ref( x, @@ -28,6 +53,7 @@ def layer_norm_ref( dropout_p=0.0, rowscale=None, prenorm=False, + zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, @@ -41,6 +67,10 @@ def layer_norm_ref( x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: @@ -83,6 +113,7 @@ def rms_norm_ref( dropout_p=0.0, rowscale=None, prenorm=False, + zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, @@ -96,6 +127,10 @@ def rms_norm_ref( x1 = x1.float() if x1 is not None else None weight1 = weight1.float() 
if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: @@ -126,21 +161,15 @@ def rms_norm_ref( @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], - key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], ) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) -@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) -@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input @@ -156,6 +185,7 @@ def _layer_norm_fwd_1pass_kernel( ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, + DROPOUT_MASK1, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row @@ -168,6 +198,7 @@ def _layer_norm_fwd_1pass_kernel( N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, @@ -218,7 +249,7 @@ def _layer_norm_fwd_1pass_kernel( ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: - tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) @@ -238,6 +269,8 @@ def _layer_norm_fwd_1pass_kernel( # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd @@ -246,6 +279,8 @@ def _layer_norm_fwd_1pass_kernel( tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 @@ -253,25 +288,87 @@ def _layer_norm_fwd_1pass_kernel( def _layer_norm_fwd( - x, - weight, - bias, - eps, - residual=None, - x1=None, - weight1=None, - bias1=None, - dropout_p=0.0, - rowscale=None, - out_dtype=None, - residual_dtype=None, - is_rms_norm=False, - return_dropout_mask=False, - out=None, - residual_out=None -): + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = 
None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) if residual is not None: residual_dtype = residual.dtype + if residual_out is None and ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty_like( + x, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +@triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, + schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") +def _layer_norm_fwd_impl( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): M, N = x.shape assert x.stride(-1) == 1 if residual is not None: @@ -295,33 +392,16 @@ def _layer_norm_fwd( if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) - # allocate output - if out is None: - out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) - else: - assert out.shape == x.shape + assert out.shape == x.shape assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(out) assert y1.stride(-1) == 1 else: y1 = None - if ( - residual is not None - or (residual_dtype is not None and residual_dtype != x.dtype) - or dropout_p > 0.0 - or rowscale is not None - or x1 is not None - ): - if residual_out is None: - residual_out = torch.empty( - M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype - ) - else: - assert residual_out.shape == x.shape - assert residual_out.stride(-1) == 1 - else: - residual_out = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: @@ -331,16 +411,20 @@ def _layer_norm_fwd( else: seeds = None if return_dropout_mask and dropout_p > 0.0: - dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None else: - dropout_mask = None + dropout_mask, dropout_mask1 = None, None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[(M,)]( + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( x, out, weight, @@ -354,6 +438,7 @@ def _layer_norm_fwd( rowscale, seeds, dropout_mask, + dropout_mask1, mean, rstd, x.stride(0), @@ -366,6 +451,8 @@ def _layer_norm_fwd( N, eps, dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), is_rms_norm, BLOCK_N, residual is not None, @@ -374,43 +461,26 @@ def _layer_norm_fwd( dropout_p > 0.0, dropout_mask is not None, rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, ) - # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 - if dropout_mask is not None and x1 is not None: - dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) - else: - dropout_mask1 = None - return ( - out, - y1, - mean, - rstd, - residual_out if residual_out is 
not None else x, - seeds, - dropout_mask, - dropout_mask1, - ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 @triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - triton.Config({}, num_warps=8), - triton.Config({}, num_warps=16), - triton.Config({}, num_warps=32), - ], + configs=triton_autotune_configs(), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) -@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) -@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) -@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) -@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) -@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +# @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +# @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +# @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input @@ -444,6 +514,7 @@ def _layer_norm_bwd_kernel( N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, + zero_centered_weight, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, @@ -477,10 +548,14 @@ def _layer_norm_bwd_kernel( if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) if HAS_DY1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) @@ -566,31 +641,93 @@ def _layer_norm_bwd_kernel( def _layer_norm_bwd( - dy, - x, - weight, - bias, - eps, - mean, - rstd, - dresidual=None, - dy1=None, - weight1=None, - bias1=None, - seeds=None, - dropout_p=0.0, - rowscale=None, - has_residual=False, - has_x1=False, - is_rms_norm=False, - x_dtype=None, - recompute_output=False, -): + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x, + # which makes torch.library unhappy + dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual, 
+ dy1, + weight1, + bias1, + seeds, + dropout_p, + rowscale, + has_residual, + has_x1, + zero_centered_weight, + is_rms_norm, + x_dtype=x_dtype, + recompute_output=recompute_output, + ) + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y + + + +@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={}, + schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)", + allow_decomposition=False, # Don't let torch.compile trace inside + ) +def _layer_norm_bwd_impl( + dy: Tensor, + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + mean: Tensor, + rstd: Tensor, + dresidual: Optional[Tensor] = None, + dy1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + seeds: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + has_residual: bool = False, + has_x1: bool = False, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + x_dtype: Optional[torch.dtype] = None, + recompute_output: bool = False, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): M, N = x.shape assert x.stride(-1) == 1 + dy = maybe_contiguous_lastdim(dy) assert dy.stride(-1) == 1 assert dy.shape == (M, N) if dresidual is not None: + dresidual = maybe_contiguous_lastdim(dresidual) assert dresidual.stride(-1) == 1 assert dresidual.shape == (M, N) assert weight.shape == (N,) @@ -599,6 +736,7 @@ def _layer_norm_bwd( assert bias.stride(-1) == 1 assert bias.shape == (N,) if dy1 is not None: + dy1 = maybe_contiguous_lastdim(dy1) assert weight1 is not None assert dy1.shape == dy.shape assert dy1.stride(-1) == 1 @@ -636,7 +774,9 @@ def _layer_norm_bwd( BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. 
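# A minimal eager-PyTorch sketch (assumed names; not the Triton kernel) of the partial-sum
# pattern used for dw/db here: each program accumulates one fp32 row of partial gradients
# into a (num_programs, N) buffer, and the final dw/db is the sum over rows. x_hat stands
# for the normalized input (x - mean) * rstd.
import torch

def dw_db_partial_sum_reference(dy, x_hat, num_programs):
    # dy, x_hat: (M, N). Returns dw, db of shape (N,), accumulated in float32.
    M, N = dy.shape
    _dw = torch.zeros(num_programs, N, dtype=torch.float32, device=dy.device)
    _db = torch.zeros(num_programs, N, dtype=torch.float32, device=dy.device)
    rows_per_program = -(-M // num_programs)  # same as math.ceil(M / num_programs)
    for p in range(num_programs):
        rows = slice(p * rows_per_program, min((p + 1) * rows_per_program, M))
        _dw[p] = (dy[rows].float() * x_hat[rows].float()).sum(dim=0)
        _db[p] = dy[rows].float().sum(dim=0)
    return _dw.sum(dim=0), _db.sum(dim=0)

# Raising num_programs launches more blocks (better latency hiding) but makes the final
# .sum(dim=0) over the partial buffers more expensive, which is the trade-off noted above.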
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) @@ -648,7 +788,7 @@ def _layer_norm_bwd( rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): - _layer_norm_bwd_kernel[grid]( + torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid]( x, weight, bias, @@ -680,6 +820,8 @@ def _layer_norm_bwd( N, eps, dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), rows_per_program, is_rms_norm, BLOCK_N, @@ -687,24 +829,22 @@ def _layer_norm_bwd( dresidual_in is not None, bias is not None, dropout_p > 0.0, + HAS_ROWSCALE=rowscale is not None, + HAS_DY1=dy1 is not None, + HAS_DX1=dx1 is not None, + HAS_B1=bias1 is not None, + RECOMPUTE_OUTPUT=y is not None, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None - # Don't need to compute dresidual_in separately in this case - if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: - dresidual_in = dx - if has_x1 and dropout_p == 0.0: - dx1 = dx - return ( - (dx, dw, db, dresidual_in, dx1, dw1, db1) - if not recompute_output - else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) - ) + # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx + return dx, dw, db, dresidual_in, dx1, dw1, db1, y class LayerNormFn(torch.autograd.Function): + @staticmethod def forward( ctx, @@ -720,34 +860,27 @@ def forward( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): x_shape_og = x.shape # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) if residual is not None: assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) if x1 is not None: assert x1.shape == x_shape_og assert rowscale is None, "rowscale is not supported with parallel LayerNorm" - x1 = x1.reshape(-1, x1.shape[-1]) - if x1.stride(-1) != 1: - x1 = x1.contiguous() + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - if weight1 is not None: - weight1 = weight1.contiguous() - if bias1 is not None: - bias1 = bias1.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) if rowscale is not None: rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( @@ -770,11 +903,13 @@ def forward( bias1, dropout_p=dropout_p, rowscale=rowscale, + out_dtype=out_dtype, residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, out=out, - residual_out=residual_out + residual_out=residual_out, ) ctx.save_for_backward( residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd @@ -787,6 +922,7 @@ def forward( ctx.has_x1 = x1 is not 
None ctx.prenorm = prenorm ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight y = y.reshape(x_shape_og) y1 = y1.reshape(x_shape_og) if y1 is not None else None residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None @@ -815,26 +951,19 @@ def forward( def backward(ctx, dy, *args): x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) - if dy.stride(-1) != 1: - dy = dy.contiguous() - assert dy.shape == x.shape if weight1 is not None: dy1, args = args[0], args[1:] dy1 = dy1.reshape(-1, dy1.shape[-1]) - if dy1.stride(-1) != 1: - dy1 = dy1.contiguous() assert dy1.shape == x.shape else: dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None - dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd( dy, x, weight, @@ -851,8 +980,10 @@ def backward(ctx, dy, *args): rowscale, ctx.has_residual, ctx.has_x1, + ctx.zero_centered_weight, ctx.is_rms_norm, x_dtype=ctx.x_dtype, + recompute_output=False, ) return ( dx.reshape(ctx.x_shape_og), @@ -871,6 +1002,8 @@ def backward(ctx, dy, *args): None, None, None, + None, + None, ) @@ -887,8 +1020,10 @@ def layer_norm_fn( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): @@ -905,8 +1040,10 @@ def layer_norm_fn( rowscale, prenorm, residual_in_fp32, + zero_centered_weight, is_rms_norm, return_dropout_mask, + out_dtype, out, residual_out ) @@ -925,7 +1062,9 @@ def rms_norm_fn( rowscale=None, prenorm=False, residual_in_fp32=False, + zero_centered_weight=False, return_dropout_mask=False, + out_dtype=None, out=None, residual_out=None ): @@ -942,8 +1081,10 @@ def rms_norm_fn( rowscale, prenorm, residual_in_fp32, + zero_centered_weight, True, return_dropout_mask, + out_dtype, out, residual_out ) @@ -951,7 +1092,8 @@ def rms_norm_fn( class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps @@ -959,12 +1101,16 @@ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None self.drop = torch.nn.Dropout(dropout_p) else: self.drop = None + self.zero_centered_weight = zero_centered_weight self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): - torch.nn.init.ones_(self.weight) + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): return rms_norm_fn( @@ -976,10 +1122,12 @@ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, ) class LayerNormLinearFn(torch.autograd.Function): + @staticmethod @custom_fwd def forward( @@ -997,17 +1145,12 @@ def forward( ): x_shape_og = x.shape # reshape 
input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) if residual is not None: assert residual.shape == x_shape_og - residual = residual.reshape(-1, residual.shape[-1]) - if residual.stride(-1) != 1: - residual = residual.contiguous() + residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1])) norm_weight = norm_weight.contiguous() - if norm_bias is not None: - norm_bias = norm_bias.contiguous() + norm_bias = maybe_contiguous(norm_bias) residual_dtype = ( residual.dtype if residual is not None @@ -1019,12 +1162,12 @@ def forward( norm_bias, eps, residual, - out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), residual_dtype=residual_dtype, is_rms_norm=is_rms_norm, ) y = y.reshape(x_shape_og) - dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype linear_weight = linear_weight.to(dtype) linear_bias = linear_bias.to(dtype) if linear_bias is not None else None out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) @@ -1046,14 +1189,11 @@ def backward(ctx, dout, *args): dout = dout.reshape(-1, dout.shape[-1]) dy = F.linear(dout, linear_weight.t()) dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) - if dy.stride(-1) != 1: - dy = dy.contiguous() + dy = maybe_contiguous_lastdim(dy) assert dy.shape == x.shape if ctx.prenorm: dresidual = args[0] - dresidual = dresidual.reshape(-1, dresidual.shape[-1]) - if dresidual.stride(-1) != 1: - dresidual = dresidual.contiguous() + dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1])) assert dresidual.shape == x.shape else: dresidual = None diff --git a/flash_attn/ops/triton/mlp.py b/flash_attn/ops/triton/mlp.py index b795310f1c..059f4f8a5e 100644 --- a/flash_attn/ops/triton/mlp.py +++ b/flash_attn/ops/triton/mlp.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.cuda.amp import custom_bwd, custom_fwd +from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 0ee56d6477..ff4017fda3 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -1,4 +1,5 @@ -# Copyright (c) 2023, Tri Dao. +# Copyright (c) 2025, Tri Dao. 
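# For reference alongside the rotary.py rewrite below: a minimal PyTorch sketch of the math
# the Triton kernel implements (assumed shapes, seqlen_offsets taken as 0; not the kernel
# itself). cos/sin have shape (seqlen_ro, rotary_dim // 2); conjugate=True negates sin,
# matching the CONJUGATE path.
import torch

def apply_rotary_ref(x, cos, sin, interleaved=False, conjugate=False):
    # x: (batch, seqlen, nheads, headdim); only the first rotary_dim features are rotated.
    rotary_dim = 2 * cos.shape[-1]
    cos = cos[: x.shape[1], None, :].to(torch.float32)  # broadcast over heads
    sin = sin[: x.shape[1], None, :].to(torch.float32)
    if conjugate:
        sin = -sin
    x_rot, x_pass = x[..., :rotary_dim].float(), x[..., rotary_dim:]
    if not interleaved:
        x0, x1 = x_rot.chunk(2, dim=-1)
        out = torch.cat([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1)
    else:
        x0, x1 = x_rot[..., 0::2], x_rot[..., 1::2]
        out = torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).flatten(-2)
    return torch.cat([out.to(x.dtype), x_pass], dim=-1)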
+# As of 2025-04-23, we require triton >= 3.0 from typing import Optional, Union @@ -18,7 +19,7 @@ def rotary_kernel( SEQLEN_OFFSETS, # this could be int or a pointer # Matrix dimensions seqlen, - rotary_dim, + nheads, seqlen_ro, # strides stride_out_batch, @@ -30,104 +31,72 @@ def rotary_kernel( stride_x_nheads, stride_x_headdim, # Meta-parameters - BLOCK_K: tl.constexpr, + # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that + # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 + ROTARY_DIM: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, BLOCK_M: tl.constexpr, ): - pid_m = tl.program_id(axis=0) - pid_batch = tl.program_id(axis=1) - pid_head = tl.program_id(axis=2) - rotary_dim_half = rotary_dim // 2 + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) if not IS_VARLEN: - X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads - OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch else: start_idx = tl.load(CU_SEQLENS + pid_batch) seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx - X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads - OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen if pid_m * BLOCK_M >= seqlen: return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) if not IS_SEQLEN_OFFSETS_TENSOR: rm_cs = rm + SEQLEN_OFFSETS else: rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) - rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin if not INTERLEAVED: # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT - X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) - cos = tl.load( - COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 - ).to(tl.float32) - sin = tl.load( - SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x0 = tl.load( - X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 - ).to(tl.float32) - x1 = tl.load( - X + rotary_dim_half * stride_x_headdim, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - if CONJUGATE: - sin = -sin + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, 
None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) o0 = x0 * cos - x1 * sin o1 = x0 * sin + x1 * cos - # write back result - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) - tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) - tl.store( - OUT + rotary_dim_half * stride_out_headdim, - o1, - mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), - ) + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) else: - # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. - # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. - # Loading x0 will be fast but x1 will be slow. - # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. - # Then we do the calculation and use tl.where to pick put the right outputs for the even - # and for the odd indices. - rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... - rk_repeat = tl.arange(0, BLOCK_K) // 2 - X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) - X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) - COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) - cos = tl.load( - COS, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=1.0, - ).to(tl.float32) - sin = tl.load( - SIN, - mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), - other=0.0, - ).to(tl.float32) - x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( - tl.float32 - ) - x1 = tl.load( - X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 - ).to(tl.float32) - if CONJUGATE: - sin = -sin - x0_cos = x0 * cos - x1_sin = x1 * sin - out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) - OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) - tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + rk = tl.arange(0, BLOCK_K) + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) def apply_rotary( @@ -169,13 +138,6 @@ def apply_rotary( assert headdim <= 256, "Only support headdim <= 256" assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" - assert ( - cos.dtype == sin.dtype - ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" - assert ( - x.dtype == cos.dtype - ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" - cos, sin = cos.contiguous(), sin.contiguous() if isinstance(seqlen_offsets, torch.Tensor): assert seqlen_offsets.shape == (batch,) @@ -188,18 +150,13 @@ def apply_rotary( if 
rotary_dim < headdim and not inplace: output[..., rotary_dim:].copy_(x[..., rotary_dim:]) - BLOCK_K = ( - 32 - if rotary_dim <= 32 - else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) - ) - grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa - BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4) + grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa + BLOCK_M = 8 if rotary_dim <= 128 else 4 # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) with torch.cuda.device(x.device.index): - rotary_kernel[grid]( + torch.library.wrap_triton(rotary_kernel)[grid]( output, # data ptrs x, cos, @@ -207,7 +164,7 @@ def apply_rotary( cu_seqlens, seqlen_offsets, seqlen, # shapes - rotary_dim, + nheads, seqlen_ro, output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 output.stride(-3), # seqlen_stride or total_seqlen_stride @@ -217,11 +174,12 @@ def apply_rotary( x.stride(-3), # seqlen stride or total_seqlen_stride x.stride(-2), # nheads stride x.stride(-1), # headdim stride - BLOCK_K, + rotary_dim, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, - BLOCK_M, + BLOCK_M=BLOCK_M, + BLOCK_H=2, ) return output diff --git a/flash_attn/pyproject.toml b/flash_attn/pyproject.toml index 3201555763..ce5eac916c 100644 --- a/flash_attn/pyproject.toml +++ b/flash_attn/pyproject.toml @@ -1,3 +1,6 @@ [tool.black] line-length = 100 -target-version = ['py38'] \ No newline at end of file +target-version = 'py39' +[tool.ruff] +line-length = 100 +target-version = 'py39' \ No newline at end of file diff --git a/flash_attn/utils/library.py b/flash_attn/utils/library.py new file mode 100644 index 0000000000..05324bb01a --- /dev/null +++ b/flash_attn/utils/library.py @@ -0,0 +1,66 @@ +# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py +# The PyTorch implementation simply ignores the schema argument, we simply modify it to use schema. + +from typing import Optional, Callable, Iterable, Union + +from torch.library import custom_op, CustomOpDef +from torch._library.triton import set_wrap_triton_enabled + + +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, + # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False, + # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator + # and so inductor can't trace inside. + allow_decomposition=True, +) -> Callable: + def dec(fn: Callable[..., object]) -> CustomOpDef: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_wrap_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + # This is the only difference with the PyTorch implementation + schema=schema, + ) + from torch._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + if allow_decomposition: + # We decompose the operator when FunctionalTensorMode is active. 
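# For context, a hypothetical usage sketch of the triton_op wrapper being defined here
# (op name, arguments, and body are illustrative only, not part of flash-attn; a real op
# would launch a Triton kernel through torch.library.wrap_triton, but a plain PyTorch body
# keeps the sketch self-contained):
import torch
from torch import Tensor
from flash_attn.utils.library import triton_op

@triton_op(
    "flash_attn::scale_into_impl",              # hypothetical op name
    mutates_args={"out"},                       # `out` is written in place
    schema="(Tensor x, float alpha, Tensor(a!) out) -> Tensor",
)
def scale_into_impl(x: Tensor, alpha: float, out: Tensor) -> Tensor:
    out.copy_(x * alpha)
    return x.sum()

# x = torch.randn(4, 8, device="cuda"); out = torch.empty_like(x)
# total = scale_into_impl(x, 0.5, out)  # dispatches through the registered custom op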
+ # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. + def functional_decomp( # type: ignore[no-untyped-def] + mode, op, types, args, kwargs + ): + from torch.export._trace import custom_triton_ops_decomposition_disabled + + if custom_triton_ops_decomposition_disabled(): + return mode.__torch_dispatch__(op, types, args, kwargs) + else: + with mode: + return fn(*args, **kwargs) + + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + + return result + + if fn is None: + return dec + else: + return dec(fn) diff --git a/flash_attn/utils/testing.py b/flash_attn/utils/testing.py new file mode 100644 index 0000000000..81be51f1de --- /dev/null +++ b/flash_attn/utils/testing.py @@ -0,0 +1,360 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +import math +from typing import Optional + +import torch +from einops import rearrange, repeat + +from flash_attn.bert_padding import pad_input, unpad_input + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint( + max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) + elif mode == "third": + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + + if zero_lengths: + # Generate zero-lengths every 5 batches and the last batch. + for i in range(batch_size): + if i % 5 == 0: + lengths[i] = 0 + lengths[-1] = 0 + padding_mask = ( + repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + ) + return padding_mask + + +def generate_qkv( + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, + query_unused_mask=None, key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...") if qv is not None else None + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + else: + dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + qv.detach() if qv is not None else None, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(None, None), + sink_token_length=0, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] is None: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + 
torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + ) + + +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # Subtract remainder instead of divide and then multiply to take care of negative values + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + qv=None, + q_descale=None, k_descale=None, v_descale=None, + window_size=(None, None), + attention_chunk=0, + sink_token_length=0, + learnable_sink: Optional[torch.Tensor] = None, + softcap=0.0, + upcast=True, + reorder_ops=False, + intermediate_dtype=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering.
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim_v) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None + if q_descale is not None: + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None + if k_descale is not None: + k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) + if v_descale is not None: + v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) + if softcap > 0: + scores = torch.tanh(scores / softcap) * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None + if window_size[0] is not None or window_size[1] is not None: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + sink_token_length, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + if local_mask is not None: + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + if learnable_sink is None: + attention = torch.softmax(scores, dim=-1).to(v.dtype) + else: + scores_fp32 = scores.to(torch.float32) + logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) + learnable_sink = rearrange(learnable_sink, "h -> h 1 1") + logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) + unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max) + attention = (unnormalized_scores / normalizer).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + # Without this we might get NaN in dv + if key_padding_mask is not None: + attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if 
dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + if intermediate_dtype is not None: + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_attn/utils/torch.py b/flash_attn/utils/torch.py new file mode 100644 index 0000000000..98cbf9a274 --- /dev/null +++ b/flash_attn/utils/torch.py @@ -0,0 +1,21 @@ +import torch +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 5f7522a8ac..33e5d28271 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -1,6 +1,7 @@ from collections import namedtuple from functools import partial import math +import os from typing import NamedTuple import torch import torch.nn as nn @@ -34,6 +35,8 @@ triton_attention = None triton_attention = None +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" + def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # # Warmup @@ -53,10 +56,10 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc) # # return time_f[1].mean # return time_f[1] - return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) + return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3) -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)): +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 else: @@ -67,7 +70,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size= col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2 + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) def convert_to_cudnn_type(torch_type): @@ -242,21 +245,6 @@ def run(*args, **kwargs): time_f = {} time_b = {} -# tflops_matmul = {} -# m, n = 8192, 8192 -# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]: -# a = torch.randn(m, k, device=device, dtype=dtype) -# b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) -# nFLOPS_matmul = 2 * m * n * k -# m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS') -# print(f'cuBLAS: {m5.mean * 
1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') -# tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12 -# # import pickle -# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: -# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp: -# # pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL) -# exit(0) - # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: @@ -265,6 +253,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128, 192, 256]: for headdim in [128]: nheads = dim // headdim + # nheads = 128 # headdim = 64 # batch_size = 64 # seqlen = 512 @@ -272,12 +261,16 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + # nheads_kv = 1 + headdim_v = headdim + # headdim_v = 512 + has_qv = headdim == 64 and headdim_v == 512 + # has_qv = False for batch_size, seqlen in bs_seqlen_vals: - num_splits = 1 + num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) - sink_token_length = 0 pack_gqa = None # seqlen_q = 64 seqlen_q = seqlen @@ -285,20 +278,17 @@ def run(*args, **kwargs): # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor + qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) - a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) - b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) - # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype) - # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2) if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q @@ -318,16 +308,16 @@ def run(*args, 
**kwargs): page_table = None for causal in [False, True]: - # for causal in [False]: + # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: if not varlen: m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') @@ -343,7 +333,7 @@ def run(*args, **kwargs): repeats=repeats, verbose=False, desc='Fav2') time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]] time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark @@ -356,7 +346,7 @@ def run(*args, **kwargs): # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean @@ -368,25 +358,20 @@ def run(*args, **kwargs): time.sleep(1) if not varlen: - # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') - m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, 
repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: - m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') + m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - # time.sleep(1) - # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False) - # nFLOPS_matmul = nFLOPS - # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] - # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: time.sleep(1) if not varlen: - _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, + _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') else: - _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, + _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean # time.sleep(1) @@ -396,11 +381,11 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS') # if 
causal: @@ -409,7 +394,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') @@ -419,7 +404,8 @@ def run(*args, **kwargs): # import pickle # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp: - # with open(f'flash3_attn_time_h100_fa3_20241208.plk', 'wb') as fp: + # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp: + # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp: # pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/hopper/benchmark_mla_decode.py b/hopper/benchmark_mla_decode.py new file mode 100644 index 0000000000..99b1b7a329 --- /dev/null +++ b/hopper/benchmark_mla_decode.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, Ted Zadouri, Tri Dao. + +# We recommend locking GPU clocks before running the benchmark to ensure consistent results. +# This can be done using the following commands (1830 MHz is the clock for H100): +# sudo nvidia-smi -i 0 -pm 1 +# sudo nvidia-smi -i 0 --lock-gpu-clocks 1830,1830 +# See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487 + +import time +import torch +import torch.nn.functional as F + +from triton.testing import do_bench, do_bench_cudagraph + +from einops import rearrange + +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + +try: + from flash_mla import flash_mla_with_kvcache, get_mla_metadata +except ImportError: + flash_mla_with_kvcache, get_mla_metadata = None, None + +try: + from flash_attn.utils.benchmark import pytorch_profiler +except ImportError: + pytorch_profiler = None + + +device = "cuda" +dtype = torch.bfloat16 +seqlen = 8192 +seqlen_q = 1 +# nheads_q = 16 +nheads_q = 128 + +use_bench_cudagraph = False + +attn_variants = ["mha", "gqa", "mqa", "mla", "gla"] +# for attn_variant in attn_variants: +for attn_variant in attn_variants[3:5]: + nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else (1 if attn_variant == "mla" else 2)) + headdim = 64 if attn_variant in ["mla", "gla"] else 128 + headdim_v = 512 if attn_variant == "mla" else (256 if attn_variant == "gla" else headdim) + has_qv = headdim == 64 and headdim_v > 64 + # page_size = None + page_size = 64 if attn_variant in ["mla", "gla"] else 128 + + should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None + + torch.manual_seed(0) + + batch_size = 128 + cache_seqlens = None + # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32) + # cache_seqlens = torch.tensor([1024] * batch_size, 
device=device, dtype=torch.int) + # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device) + + print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}") + + for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]: + # for seqlen in [s * 1024 for s in [8]]: + cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int) + num_splits = 0 + q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device) + try: + v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device) + k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device) + if page_size is not None: + assert seqlen % page_size == 0 + k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]] + page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), + "(b s) -> b s", s=seqlen // page_size) + else: + page_table = None + except torch.OutOfMemoryError: + continue + qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None + + # Precomputing this saves ~2us + scheduler_metadata = get_scheduler_metadata( + batch_size, seqlen_q, seqlen, nheads_q, nheads_kv, headdim, + cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True + ) + # scheduler_metadata = None + # breakpoint() + fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata) + time.sleep(1) # to avoid power throttling + # Time in ms + if not use_bench_cudagraph: + t0 = do_bench(fn0, warmup=1, rep=10) + else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. k_cache might not be ready + with torch.cuda.stream(torch.cuda.Stream()): + t0 = do_bench_cudagraph(fn0, rep=10) + # exit(0) + if should_run_flashmla: + # Separate out the preprocessing since this can be done once and reused for all layers + mla_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv) + q_concat = torch.concat([q, qv], dim=-1) if has_qv else q + kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1) + fn1 = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *mla_metadata, causal=True) + time.sleep(1) # to avoid power throttling + if not use_bench_cudagraph: + t1 = do_bench(fn1, warmup=1, rep=10) + else: + torch.cuda.synchronize() # Gotta wait, otherwise e.g. 
k_cache might not be ready + with torch.cuda.stream(torch.cuda.Stream()): + t1 = do_bench_cudagraph(fn1, rep=10) + + total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item() + mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2 # last term is for the output + flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2 + ideal_h100_time_mem = mem_io / 3.35e12 * 1e6 + ideal_h100_time_flop = flops / 989e12 * 1e6 + ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop) + print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.1f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s") + if should_run_flashmla: + print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.1f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s") + print(f"Arithmetic intensity: {flops / mem_io:.1f}") + print(f"Ideal time: {ideal_h100_time:.0f} us") + + # if pytorch_profiler is not None: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn0) + # if should_run_flashmla: + # time.sleep(1) # to avoid power throttling + # pytorch_profiler(fn1) diff --git a/hopper/benchmark_split_kv.py b/hopper/benchmark_split_kv.py index d3d83590a9..f3c8af9177 100644 --- a/hopper/benchmark_split_kv.py +++ b/hopper/benchmark_split_kv.py @@ -18,13 +18,13 @@ def timeit(fn, *args, **kwargs): # Warmup for _ in range(5): fn(*args, **kwargs) - + # Benchmark using PyTorch Timer t = benchmark.Timer( stmt='fn(*args, **kwargs)', globals={'fn': fn, 'args': args, 'kwargs': kwargs} ) - + # Measure execution time measurement = t.timeit(20) # Runs the function 20 times # measurement = t.blocked_autorange(min_run_time=1) @@ -38,14 +38,15 @@ def main(): ).multi_processor_count max_splits = 129 - check_all_splits = False + check_all_splits = True causal = True # causal = False # dtype=torch.float16 dtype=torch.bfloat16 + tp_degree = 1 - torch.manual_seed(42) + torch.manual_seed(42) model_configs = [ # ("Gemma-2-2B", 8, 4, 256), @@ -56,6 +57,7 @@ def main(): # ("Qwen-2.5-7B", 28, 4, 128), # ("Llama-3.1-8B", 32, 8, 128), ("Llama-3.1-70B", 64, 8, 128), + # ("Mistral Large", 96, 8, 128), # ("Llama-3.1-405B", 128, 8, 128), # ("Llama-3.2-1B", 32, 8, 64), # ("Llama-3.2-3B", 24, 8, 128), @@ -66,28 +68,32 @@ def main(): all_batch_configs.extend(itertools.product( # [1024, 2048, 4096, 8192, 16384, 32768, 131072], # context_seqlen - [4096, 16384, 65536], # context_seqlen - # [131072], # context_seqlen + # [4096, 16384, 65536], # context_seqlen + [131072], # context_seqlen # [i for i in range(1, (num_sms) + 1)], # num_requests [1, 4, 8, 16], # num_requests # [1], # num_requests - [1, 4, 8, 16], # query_seqlen - # [1], # query_seqlen + # [1, 4, 8, 16], # query_seqlen + [1], # query_seqlen )) num_caches = max(reqs for _, reqs, _ in all_batch_configs) cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs) for model_name, nheads_q, nheads_kv, headdim in model_configs: + assert nheads_kv % tp_degree == 0 + print(f"***{model_name}***") + print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}") + nheads_q //= tp_degree + nheads_kv //= tp_degree + k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype ) v_cache = torch.randn( (num_caches, 
cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype ) - print(f"***{model_name}***") - print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}") - + if check_all_splits is False: print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}") @@ -139,7 +145,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=1, ) * 1000. * 1000. @@ -151,16 +157,16 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=0, - max_seqlen_k_hint=context_seqlen + # max_seqlen_k_hint=context_seqlen ) * 1000. * 1000. if check_all_splits: - + fa3_fastest_num_splits = 0 fa3_fastest_splitk_time = float("inf") - + for num_splits in range(1, max_splits): t = timeit( flash_attn_interface.flash_attn_with_kvcache, @@ -170,7 +176,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=num_splits ) * 1000. * 1000. @@ -181,7 +187,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=num_splits ) @@ -192,7 +198,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=False, + pack_gqa=False, num_splits=1 ) @@ -220,7 +226,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=num_splits ) * 1000. * 1000. @@ -231,7 +237,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=num_splits ) @@ -242,7 +248,7 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=1 ) @@ -257,7 +263,7 @@ def main(): if t < fa3_fastest_splitk_time_gqa: fa3_fastest_splitk_time_gqa = t fa3_fastest_num_splits_gqa = num_splits - + efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa # remeasure to smooth anomalies @@ -271,11 +277,11 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, # num_splits=num_splits_select, # num_splits=1, num_splits=0, - max_seqlen_k_hint=context_seqlen + # max_seqlen_k_hint=context_seqlen ) * 1000. * 1000. fa3_fastest_splitk_time_gqa = timeit( @@ -286,9 +292,9 @@ def main(): cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, - gqa_parallel=True, + pack_gqa=True, num_splits=fa3_fastest_num_splits_gqa - ) * 1000. * 1000. + ) * 1000. * 1000. 
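The efficiency value computed in this hunk is the ratio of split work tiles to SMs: a value near 1.0 means the fastest split factor lands roughly one work tile per SM, i.e. one full wave. A minimal Python sketch of that arithmetic, using illustrative numbers rather than anything measured by this script (132 is the H100 SXM SM count; num_work_tiles and the split factor here are hypothetical):

# Illustrative numbers only -- the benchmark derives num_work_tiles and
# num_sms from the actual device and batch configuration.
num_sms = 132                 # H100 SXM multiprocessor count
num_work_tiles = 16           # e.g. batch * kv-head tiles before splitting
fastest_num_splits = 8        # split factor that gave the lowest measured latency
efficiency = (num_work_tiles * fastest_num_splits) / num_sms
print(f"wave efficiency: {efficiency:.2f}")   # ~0.97 -> about one tile per SM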
if check_all_splits is True: print( @@ -308,7 +314,7 @@ def main(): # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, " f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, " f"EFF:{efficiency:.2f}, " - f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" + f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}" ) if check_all_splits is False: @@ -322,4 +328,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/hopper/block.h b/hopper/block.h new file mode 100644 index 0000000000..e69eede49a --- /dev/null +++ b/hopper/block.h @@ -0,0 +1,139 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +template +struct BlockMN { + + static + CUTLASS_DEVICE + cute::tuple get_n_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_q = seqlen_info.seqlen_q; + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal || Is_local) { + int m_idx_max = (m_block + 1) * kBlockM; + // TODO: check off-by-1 error + if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx)); + } + n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN)); + } + int n_block_min = 0; + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + int const n_idx = m_idx_min + seqlen_k - seqlen_q; + int n_idx_left = n_idx - window_size_left; + if (attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + n_block_min = std::max(int(0), n_idx_left / kBlockN); + } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + if constexpr (Split) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + int split_idx_actual = split_idx & 0x0000FFFF; + int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + int num_n_blocks_per_split = n_block_max <= n_block_min ? 
0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual); + n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split; + n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } + } + // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } + return {n_block_min, n_block_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_n_block_k_new_min_max( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const bidb, int const split_idx, int const num_splits, + int const window_size_left, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + + auto [n_block_min, n_block_max] = get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, num_splits, + window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod); + int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); + int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); + int const n_block_new_min = idx_k_new_min / kBlockN; + int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; + // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} + return {n_block_new_min, n_block_new_max}; + } + + static + CUTLASS_DEVICE + cute::tuple get_m_block_min_max( + SeqlenInfo_t const& seqlen_info, + int const n_block, int const bidb, + int const window_size_left, int const window_size_right, int const sink_token_length) { + // TODO: support attention_chunk + int const seqlen_q = seqlen_info.seqlen_q; + int const seqlen_k = seqlen_info.seqlen_k; + int m_block_max = cute::ceil_div(seqlen_q, kBlockM); + if constexpr (Is_local) { + if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM)); + } + } + int m_block_min = 0; + if constexpr (Is_causal || Is_local) { + m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM); + } + return {m_block_min, m_block_max}; + } + + // If we have separate iterations with causal or local masking at the start, where do we stop + static + CUTLASS_DEVICE + int get_n_block_min_causal_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_right, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_min = !PackGQA ? 
m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM); + int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx)); + } + return std::max(n_block_min, n_idx_right / kBlockN); + } + + // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations + static + CUTLASS_DEVICE + int get_n_block_min_before_local_mask( + SeqlenInfo_t const& seqlen_info, + int const m_block, int const n_block_min, int const window_size_left, + cutlass::FastDivmod const& attention_chunk_divmod, + cutlass::FastDivmod const& qhead_per_khead_divmod) { + int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q; + int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left; + if (Is_local && attention_chunk_divmod.divisor > 0) { + n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx)); + } + return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN)); + } + +}; + +} // namespace flash diff --git a/hopper/combine.h b/hopper/combine.h deleted file mode 100644 index c26f7ea562..0000000000 --- a/hopper/combine.h +++ /dev/null @@ -1,248 +0,0 @@ - -#pragma once - -#include - -#include -#include "cutlass/layout/layout.h" -#include -#include - -#include "kernel_traits.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SharedStorageLSE { - cute::array_aligned> smem_lse; - cute::array_aligned> smem_valid_splits; -}; - -// DONT use Kernel_traits here to avoid redundant compilation. -// template -template -__global__ void combine_attn_seqk_parallel(Params const params) { - // using Element = typename Kernel_traits::OutputType; - // using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = int64_t; // Kernel_traits::index_t - constexpr int kMaxSplits = 1 << Log_max_splits; - // constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNThreads = 128; //Kernel_traits::kNThreads; - - static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); - static_assert(kNThreads == 128, "We assume that each block has 128 threads"); - - // Shared memory. - // kBlockM + 1 instead of kBlockM to reduce bank conflicts. - //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; - extern __shared__ char smem_[]; - using SharedStorage = SharedStorageLSE, Int>, Shape>>; - SharedStorage &shared_storage = - *reinterpret_cast(smem_); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); - Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); - - // The thread and block index. 
- const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - - const index_t lse_size = params.b * params.h * params.seqlen_q; - //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - - const index_t row_offset_lse = bidx * kBlockM; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), - Shape, Int>{}, - make_stride(lse_size, _1{})); - - // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. - // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. - Layout flat_layout = make_layout(lse_size); - Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); - auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); - Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); - Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); - - Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); - - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - - // Read the LSE values from gmem and store them in shared memory, then transpose them. - constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadLSE + tidx / kBlockM; - const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; - if (row < kMaxSplits) { sLSE(row,col) = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } - } - __syncthreads(); - - // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) - // One thread per split. Know NumThreads = 128 >= NumMaxSplits - if (tidx < kMaxSplits) { - bool is_valid_split = false; - #pragma unroll - for (int col = 0; col < kBlockM; ++col) { - if(sLSE(tidx,col) != -INFINITY) { - is_valid_split = true; - } - } - sValidSplits(tidx) = is_valid_split; - } - __syncthreads(); - // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } - - Tensor lse_accum = make_tensor(Shape>{}); - constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); - // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits - // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // kBlockM rows, so each time we load we can load 128 / kBlockM rows). 
- // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; - // static_assert(kThreadsPerSplit <= 32); - static_assert(kRowsPerLoadTranspose <= 32); - static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; - - } - //return; - - // Compute the logsumexp of the LSE along the split dimension. - ElementAccum lse_max = lse_accum(0); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } - SumOp sum_op; - lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - if (params.unpadded_lse) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (lse_offset < lse_size) { - gLSE_unpadded(lse_offset) = lse_logsum; - } - } else { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; - } - } - //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); - - // Store the scales exp(lse - lse_logsum) in shared memory. 
- #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } - } - __syncthreads(); - - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); - constexpr int kBlockN = kNThreads / kBlockM; - using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); - //if (cute::thread0()) print_tensor (cOaccum); - // Repeat the partitioning with identity layouts - Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); - Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } - } - // Load Oaccum in then scale and accumulate to O - for (int split = 0; split < params.num_splits; ++split) { - // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. - if(sValidSplits(split)) { - flash::copy( - gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOrOaccum); ++m) { - int row = get<0>(tOcOaccum(0, m, 0)); - ElementAccum lse_scale = sLSE(split,row); - if (lse_scale != 0.f) { - #pragma unroll - for (int k = 0; k < size<2>(tOrOaccum); ++k) { - #pragma unroll - for (int i = 0; i < size<0>(tOrOaccum); ++i) { - tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); - //tOrO(i, m, k) += tOrOaccum(i, m, k); - } - } - } - //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } - } - } - tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; - } - //if (cute::thread0()) { print_tensor(tOrO); } - - Tensor rO = flash::convert_type(tOrO); - // Write to gO - #pragma unroll - for (int m = 0; m < size<1>(rO); ++m) { - const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); - //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - if (idx < params.b * params.h * params.seqlen_q) { - //print ("final2\n"); - const int batch_idx = idx / (params.h * params.seqlen_q); - const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; - // The index to the rows of Q - const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride - + head_idx * params.o_head_stride + row * params.o_row_stride; - #pragma unroll - for (int k = 0; k < size<2>(rO); ++k) { - if (Is_even_K || tOpOaccum(k)) { - const int col = get<1>(tOcOaccum(0, m, k)); - Tensor gO = 
make_tensor(make_gmem_ptr(o_ptr + col), - Shape(rO))::value>>{}, Stride<_1>{}); - // TODO: Should check if this is using vectorized store, but it seems pretty fast - copy(rO(_, m, k), gO); - //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } - // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } - // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); - } - } - } - } -} - -} // namespace flash diff --git a/hopper/cuda_check.h b/hopper/cuda_check.h new file mode 100644 index 0000000000..b5e63aef79 --- /dev/null +++ b/hopper/cuda_check.h @@ -0,0 +1,19 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) diff --git a/hopper/epilogue_bwd.hpp b/hopper/epilogue_bwd.hpp index f99dfe918e..6d9b5f4f59 100644 --- a/hopper/epilogue_bwd.hpp +++ b/hopper/epilogue_bwd.hpp @@ -4,8 +4,8 @@ #pragma once -#include -#include +#include "cutlass/cutlass.h" +#include "cutlass/barrier.h" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/builders/sm90_common.inl" @@ -107,6 +107,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; int const num_heads_q; int* dk_semaphore; @@ -121,6 +122,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; @@ -130,7 +132,7 @@ struct CollectiveEpilogueBwd { static Params to_underlying_arguments(Arguments const& args) { Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); - Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); TMA_dKV tma_store_dK = [&] { if constexpr (Use_TMA) { return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV @@ -145,7 +147,7 @@ struct CollectiveEpilogueBwd { return nullptr; } }(); - return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV, tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } @@ -197,7 +199,7 @@ struct CollectiveEpilogueBwd { cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); - Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV); Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); @@ -227,7 +229,7 
@@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -238,28 +240,31 @@ struct CollectiveEpilogueBwd { Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K) Tensor tdKVrdV = make_fragment_like(tdKVgdV); Tensor tdKVrdK = make_fragment_like(tdKVgdK); - Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } // Need to check OOB when reading from smem if kBlockN isn't evenly tiled static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; flash::copy( - gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN); flash::copy( - gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN); // // Tell warp 0 that smem_k and smem_v are ready // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); // Construct identity layout for gdKV // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); } } @@ -282,7 +287,7 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? 
bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -295,15 +300,18 @@ struct CollectiveEpilogueBwd { Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN ); } @@ -359,6 +367,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; int num_heads_q; int* dk_semaphore; @@ -373,6 +382,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; @@ -387,7 +397,7 @@ struct CollectiveEpilogueBwdGQA { assert(args.dk_semaphore != nullptr); assert(args.dv_semaphore != nullptr); } - return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, args.cu_seqlens, args.seqused}; @@ -419,7 +429,7 @@ struct CollectiveEpilogueBwdGQA { flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 0f91606026..fd3485ce6b 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -20,38 +20,41 @@ namespace flash { using namespace cute; -template +template struct CollectiveEpilogueFwd { - using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; + using ElementPartial = float; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; static constexpr bool Varlen = Varlen_; static constexpr bool PackGQA = PackGQA_; - static constexpr bool Use_smem = sizeof(Element) <= 2; - static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && Use_smem && !PackGQA; + static constexpr bool Split = Split_; + static constexpr bool Use_smem = !(Split && !Varlen); + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); + static_assert(sizeof(Element) <= 2); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + + static constexpr bool LargeHeadDimV = kHeadDimV > 256; using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 
64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -65,15 +68,15 @@ struct CollectiveEpilogueFwd { Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) @@ -86,10 +89,11 @@ struct CollectiveEpilogueFwd { using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + using EpilogueTile_MN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom; @@ -109,7 +113,7 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, - select<0, 2>(TileShape_MNK{}), + select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O std::nullptr_t >; @@ -119,8 +123,12 @@ struct CollectiveEpilogueFwd { Element* ptr_O; ShapeO const shape_O; StrideO const stride_O; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; float* ptr_LSE; StrideLSE const stride_LSE; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; int32_t const nheads_kv; int const* cu_seqlens = nullptr; int const* seqused = nullptr; @@ -133,10 +141,16 @@ struct CollectiveEpilogueFwd { StrideO const stride_O; ShapeOPacked const shape_O_packed; StrideOPacked const stride_O_packed; + ElementPartial* ptr_O_partial; + StrideO const stride_O_partial; + StrideOPacked const stride_O_partial_packed; float* ptr_LSE; StrideLSE const stride_LSE; ShapeLSEPacked const shape_LSE_packed; StrideLSEPacked const stride_LSE_packed; + float* ptr_LSE_partial; + StrideLSE const stride_LSE_partial; + StrideLSEPacked const stride_LSE_partial_packed; cutlass::FastDivmod qhead_per_khead_divmod; TMA_O tma_store_O; int const* cu_seqlens = nullptr; @@ -148,7 +162,7 @@ struct 
CollectiveEpilogueFwd { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast } else { return nullptr; } @@ -163,6 +177,10 @@ struct CollectiveEpilogueFwd { args.stride_O, make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) ); + auto const stride_O_partial_packed = cute::conditional_return( + args.stride_O_partial, + make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) + ); // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) auto const shape_LSE_packed = cute::conditional_return( select<0, 2, 3, 4>(args.shape_O), @@ -172,8 +190,14 @@ struct CollectiveEpilogueFwd { args.stride_LSE, make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) ); + auto const stride_LSE_partial_packed = cute::conditional_return( + args.stride_LSE_partial, + make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) + ); return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, + args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, + args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, cutlass::FastDivmod(qhead_per_khead), tma_store_O, args.cu_seqlens, args.seqused}; } @@ -189,7 +213,7 @@ struct CollectiveEpilogueFwd { template CUTLASS_DEVICE void store(Params const& params, - FrgTensorO const& tOrO, + FrgTensorO& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, @@ -198,12 +222,25 @@ struct CollectiveEpilogueFwd { ) { auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); + static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); + // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. + // Otherwise we can permute after conversion. 
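The split_idx decoding added earlier in this hunk (and the matching logic in block.h's get_n_block_min_max) packs two values into one 32-bit integer: the dynamic split count in the upper 16 bits, with 0 meaning "fall back to the static num_splits", and the actual split position in the lower 16 bits. A minimal Python sketch of that packing scheme, for intuition only; the function names are hypothetical, and the kernel does the equivalent with a reinterpret_cast and a 16-bit shift:

def pack_split_idx(split_pos: int, num_splits_dynamic: int) -> int:
    # Upper 16 bits: dynamic split count (0 = use static value); lower 16 bits: split position.
    return (num_splits_dynamic << 16) | (split_pos & 0xFFFF)

def unpack_split_idx(packed: int, num_splits_static: int) -> tuple[int, int]:
    num_splits_dynamic = packed >> 16
    split_pos = packed & 0xFFFF
    return split_pos, (num_splits_dynamic if num_splits_dynamic > 0 else num_splits_static)

assert unpack_split_idx(pack_split_idx(3, 5), 8) == (3, 5)   # dynamic count wins
assert unpack_split_idx(pack_split_idx(3, 0), 8) == (3, 8)   # falls back to static count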
+ if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } Tensor tOrO_out = make_tensor_like(tOrO); flash::convert_type_out(tOrO, tOrO_out); - if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } + if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } // Make sure all WGs have finished reading V // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that @@ -239,35 +276,41 @@ struct CollectiveEpilogueFwd { bool is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 
0 : split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } - if constexpr (!PackGQA) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } - } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) @@ -286,10 +329,10 @@ struct CollectiveEpilogueFwd { } } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } - if constexpr (Use_smem) { + if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, _0{}); + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -305,7 +348,7 @@ struct CollectiveEpilogueFwd { } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } @@ -316,17 +359,27 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } else { - // We already arrived on barrier_O earlier + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + // We already arrived on barrier_O earlier if !Use_smem + if constexpr (Use_smem) { + if constexpr (ArchTag::kMinComputeCapability >= 90) { + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.pipelines.barrier_O.arrive(cta_id); + } + } + } if constexpr (!PackGQA) { static constexpr int kGmemElemsPerStoreDirect = 2; - cute::Copy_Atom, Element> gmem_copy_direct; + cute::Copy_Atom, ElementPartial> gmem_copy_direct; // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) - Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); + Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); - Tensor tOgO = thread_mma.partition_C(gO); + Tensor tOgO = thread_mma.partition_C(gOpartial); Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); @@ -342,7 +395,7 @@ struct CollectiveEpilogueFwd { } } } else { - PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } } @@ -354,22 +407,31 @@ struct CollectiveEpilogueFwd { } // Write 0 to output and -inf to LSE - template CUTLASS_DEVICE void store_zero( Params const& params, int thread_idx, cute::tuple const& block_coord ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = 
get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; + int num_splits = get<4>(params.shape_O_packed); + if constexpr (Split && Varlen) { + uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits + int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); + num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; + split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx + } + bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); + flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); + Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), + params.shape_LSE_packed, + !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); static_assert(kBlockM <= NumEpilogueThreads); @@ -381,35 +443,39 @@ struct CollectiveEpilogueFwd { if (row < seqlen_o * qhead_per_khead) { int m_idx, h_idx; m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); - // mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" + // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; } } } - if constexpr (!Clear_O) { return; } - - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); - if constexpr (!PackGQA) { - Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_fragment_like(tOgO); - cute::clear(tOrO); - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM - ); - } else { - // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; - Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); - cute::clear(tOrO); - PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, + // since it will not use the value of O if LSE is -inf. 
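For readability, here is a minimal standalone sketch (not part of the diff) of the split_idx packing convention that the Split && Varlen branch above decodes: the upper 16 bits carry the dynamic num_splits (0 means "fall back to the static value") and the lower 16 bits carry this tile's split index. The pack/unpack helper names are hypothetical, for illustration only.

#include <cstdint>

// Upper 16 bits: dynamic num_splits (0 => use the static num_splits);
// lower 16 bits: this tile's split index.
inline int pack_split_idx(int split_idx, int num_splits_dynamic) {
    return (num_splits_dynamic << 16) | (split_idx & 0xFFFF);
}

inline void unpack_split_idx(int packed, int &split_idx, int &num_splits_dynamic) {
    num_splits_dynamic = int(uint32_t(packed) >> 16);
    split_idx = packed & 0x0000FFFF;
}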
+ if (!is_split) { + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{}); + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + if constexpr (!PackGQA) { + Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + cute::clear(tOrO); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM + ); + } else { + // If PackGQA, we split the work of compute O_ptr among threads in the same row + using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; + Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); + cute::clear(tOrO); + PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); + } } } diff --git a/hopper/flash.h b/hopper/flash.h index 4559a1352e..6848e8c9db 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -65,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params { int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim // The scaling factors for the kernel. float scale_softmax; @@ -103,9 +104,15 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride; index_t vnew_head_stride; + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; + int *__restrict__ seqlens_rotary; // The indices to index into the KV cache. int * __restrict__ kv_batch_idx; @@ -115,6 +122,7 @@ struct Flash_fwd_params : public Qkv_params { index_t page_table_batch_stride; int page_size; int num_pages; + bool pagedkv_tma; // The dropout probability (probability of keeping an activation). float p_dropout; @@ -127,7 +135,7 @@ struct Flash_fwd_params : public Qkv_params { // Local window size int window_size_left, window_size_right; - int sink_token_length; + int attention_chunk; // Pointer to the RNG seed (idx 0) and offset (idx 1). 
uint64_t * rng_state; @@ -144,6 +152,16 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; + int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int * __restrict__ num_splits_dynamic_ptr; + int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual + int * __restrict__ num_nheads_in_l2_ptr; + bool skip_scheduler_metadata_computation; + bool varlen_sort_batches; + int tile_count_semaphore_offset; + bool head_swizzle; + bool prepare_varlen_pdl; int arch; int num_sm; @@ -197,9 +215,10 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 82643d9fff..adb53fdab6 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -2,11 +2,9 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ -// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. -#include -#include -#include // For TORCH_VERSION* macros -#include +#include +#include +#include #include #include @@ -15,50 +13,34 @@ #include "static_switch.h" #include "tile_size.h" #include "heuristics.h" - -// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909 -// This is so that we can pass in torch.dtype as a parameter to the function. -#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4) - -#include -#include - -namespace pybind11::detail { - - template <> - struct type_caster { - public: - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype")); - // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType - // cannot be default-initialized, we provide this constructor to explicitly - // initialize that field. The value doesn't matter as it will be overwritten - // after a successful call to load. - type_caster() : value(at::kFloat) {} - bool load(handle src, bool) { - PyObject* obj = src.ptr(); - if (THPDtype_Check(obj)) { - value = reinterpret_cast(obj)->scalar_type; - return true; - } - return false; - } - static handle cast( - const at::ScalarType& src, - return_value_policy /* policy */, - handle /* parent */) { - return Py_NewRef(torch::getTHPDtype(src)); - } +#include "cuda_check.h" + + +extern "C" { +/* Creates a dummy empty _C module that can be imported from Python. + The import from Python will load the .so consisting of this file + in this extension, so that the TORCH_LIBRARY static initializers + below are run. 
*/ +PyObject* PyInit__C(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + NULL, /* methods */ }; - -} // namespace pybind11::detail - -#endif + return PyModule_Create(&module_def); +} +} #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -84,6 +66,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, + int attention_chunk, const float softcap=0.f, const int sm_margin=0) { @@ -156,14 +139,19 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. - params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; // TODO: check this - if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; } - if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (window_size_left < 0) { window_size_left = seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } params.window_size_left = window_size_left; params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; @@ -206,6 +194,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, + int attention_chunk, const float softcap=0.f, bool deterministic=false, int const sm_margin=0) { @@ -222,6 +211,7 @@ void set_params_dgrad(Flash_bwd_params ¶ms, softmax_scale, window_size_left, window_size_right, + attention_chunk, softcap, sm_margin); @@ -256,6 +246,118 @@ void set_params_dgrad(Flash_bwd_params ¶ms, params.deterministic = deterministic; } +template +void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (!params.is_e4m3) { + if (params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef 
FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 + if constexpr (Arch == 90) { + if (params.dv > 256) { + return run_mha_fwd_(params, stream); + } else if (params.dv > 64) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_(params, stream); + } + } + #endif + return run_mha_fwd_(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } + } else { + #ifndef FLASHATTENTION_DISABLE_FP8 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 + if constexpr (Arch == 90) { + if (params.dv <= 128) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + } + #endif + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP8."); + #endif + } +} + void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // HEADDIM_SWITCH(params.d, [&] { // run_mha_fwd_(params, stream); @@ -263,70 +365,12 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { TORCH_CHECK(params.num_splits >= 1); ARCH_SWITCH(params.arch, Arch, [&] { SPLIT_SWITCH(params.num_splits > 1, Split, [&] { - PAGEDKV_SWITCH(params.page_table, PagedKV, [&] { + PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation - static 
constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV || Split; + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation + static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { - if (!params.is_e4m3) { - if (params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } - } else { - #ifndef FLASHATTENTION_DISABLE_FP8 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP8."); - #endif - } + run_mha_fwd_constexpr(params, stream); }); }); }); @@ -334,33 +378,27 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { +void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) { #ifndef FLASHATTENTION_DISABLE_SPLIT // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. 
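As a reference for what the combine path computes, below is a hedged host-side sketch of the standard split-KV reduction: each split's partial output is re-weighted by exp(LSE_s - LSE_final), where LSE_final is the log-sum-exp of the per-split LSEs. This illustrates only the math, not the actual CUDA combine kernel; a split whose LSE is -inf contributes nothing, which is why store_zero above can skip writing O_partial on split paths.

#include <algorithm>
#include <cmath>
#include <vector>

// Merge per-split partial results for one (row, head): a sketch of the
// log-sum-exp combine, not the actual combine kernel.
void combine_splits(const std::vector<float>& lse_partial,            // [num_splits]
                    const std::vector<std::vector<float>>& o_partial, // [num_splits][dv]
                    float& lse_out, std::vector<float>& o_out) {
    float lse_max = -INFINITY;
    for (float lse : lse_partial) { lse_max = std::max(lse_max, lse); }
    o_out.assign(o_partial[0].size(), 0.f);
    if (lse_max == -INFINITY) { lse_out = -INFINITY; return; }  // all splits empty
    float sum_exp = 0.f;
    for (float lse : lse_partial) { sum_exp += std::exp(lse - lse_max); }
    lse_out = lse_max + std::log(sum_exp);
    for (size_t s = 0; s < lse_partial.size(); ++s) {
        float w = std::exp(lse_partial[s] - lse_out);  // 0 if this split saw no keys
        for (size_t i = 0; i < o_out.size(); ++i) { o_out[i] += w * o_partial[s][i]; }
    }
}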
if (params.is_fp32) { - if (params.d <= 64) { - run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else if (params.is_bf16) { - if (params.d <= 64) { - run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } else { - if (params.d <= 64) { - run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); + if (params.dv <= 64) { + run_mha_fwd_combine_(params, stream, enable_pdl); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream, enable_pdl); } } #else @@ -368,17 +406,28 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif } +inline bool get_pagedkv_tma(Flash_fwd_params const& params) { + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // This needs to match the kernel configs + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); + int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); + // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, + // at least for MLA. + return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; +} + inline bool get_pack_gqa(Flash_fwd_params const& params) { - // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size. + // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. // Has little effect on speed. - if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; } + if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } #ifdef FLASHATTENTION_DISABLE_PACKGQA return false; #else // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -392,10 +441,10 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); @@ -405,9 +454,13 @@ inline int get_num_splits(Flash_fwd_params const& params) { : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128); - // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, - // params.num_sm, num_n_blocks, 128, params.d_rounded); + int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); + // Always enable PackGQA for Split + // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. + // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending + // that batch = 1. + int total_mblocks = (params.num_splits_dynamic_ptr ? 
1 : params.b) * params.h_k * num_m_blocks; + return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } @@ -449,6 +502,159 @@ inline int round_up_headdim(int head_size) { return 256; } +inline int round_up_headdimv(int head_size) { + if (head_size <= 64) { return 64; } + if (head_size <= 96) { return 96; } + if (head_size <= 128) { return 128; } + if (head_size <= 192) { return 192; } + if (head_size <= 256) { return 256; } + return 512; +} + +// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available +at::Tensor +mha_fwd_get_scheduler_metadata( + int64_t batch_size, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + int64_t num_heads, + int64_t num_heads_k, + int64_t headdim, + int64_t headdim_v, + at::ScalarType qkv_dtype, + at::Tensor seqused_k, // b + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional leftpad_k_, // b + std::optional page_size, + int64_t max_seqlen_k_new, // 0 means we're not appending new KV + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + bool has_softcap, + int64_t num_splits, + std::optional pack_gqa_, + int64_t sm_margin) { + + TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, + "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // Reset the parameters + Flash_fwd_params params{}; + params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; + params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; + params.b = batch_size; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.h = num_heads; + params.h_k = num_heads_k; + params.d = headdim; + params.dv = headdim_v; + params.d_rounded = round_up_headdim(headdim); + params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); + params.seqlen_knew = max_seqlen_k_new; + + bool const is_varlen_q = cu_seqlens_q_.has_value(); + params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; + bool const is_varlen_k = cu_seqlens_k_.has_value(); + params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; + params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr() : nullptr; + params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr; + params.seqused_k = seqused_k.data_ptr(); + params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; + params.knew_ptr = params.seqlen_knew > 0 ? 
reinterpret_cast(1) : nullptr; + if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } + if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { + is_causal = false; + } + } + if (is_causal) { window_size_right = 0; } + + params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; + if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } + if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } + if (attention_chunk > 0) { + window_size_left = std::min(window_size_left, attention_chunk - 1); + window_size_right = std::min(window_size_right, attention_chunk - 1); + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.attention_chunk = attention_chunk; + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; + params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; + params.softcap = has_softcap ? 1.0f : 0.0f; + + params.page_size = page_size.has_value() ? page_size.value() : 1; + params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); + + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); + + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + bool is_varlen = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()}; + + auto opts = seqused_k.options(); + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::empty( + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + opts.dtype(torch::kInt32)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + if (scheduler_needs_semaphore) { + if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; + } else { + params.tile_count_semaphore = nullptr; + } + } + + if (use_prepare_varlen) { + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); + int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); + int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + return tile_count_semaphore; +} + // b: batch_size // b_k: batch_size_k // s_q: seqlen_q @@ -457,39 +663,42 @@ inline int round_up_headdim(int head_size) { // h: num_heads // h_k: num_heads_k // d: head_size -std::vector -mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &cu_seqlens_k_new_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. 
If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, +std::tuple +mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional cu_seqlens_k_new_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - std::optional max_seqlen_k_, - std::optional &page_table_, // (b_k, max_num_pages_per_seq) - std::optional &kv_batch_idx_, // b. indices to index into the KV cache - std::optional &leftpad_k_, // b - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &q_descale_, // (b, h_k), not (b, h) - std::optional &k_descale_, // (b, h_k) - std::optional &v_descale_, // (b, h_k) - float const softmax_scale, + std::optional max_seqlen_k_, + std::optional page_table_, // (b_k, max_num_pages_per_seq) + std::optional kv_batch_idx_, // b. 
indices to index into the KV cache + std::optional leftpad_k_, // b + std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional seqlens_rotary_, // b + std::optional q_descale_, // (b, h_k), not (b, h) + std::optional k_descale_, // (b, h_k) + std::optional v_descale_, // (b, h_k) + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - int sink_token_length, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - int num_splits, + int64_t window_size_left, + int64_t window_size_right, + int64_t attention_chunk, + double softcap, + bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional scheduler_metadata_, // (b + 1) + int64_t num_splits, std::optional pack_gqa_, - int const sm_margin + int64_t sm_margin ) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -539,11 +748,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported"); } - // This is what we will template on - bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); - #ifdef FLASHATTENTION_DISABLE_VARLEN - TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); - #endif auto const sizes = q.sizes(); const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; @@ -551,6 +755,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); @@ -558,21 +763,40 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ? 
k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } if (!kv_batch_idx_.has_value()) { TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || + (head_size <= 64 && head_size_v <= 512), + "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " + "or (Q/K <= 64 and V <= 512)."); + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM // TODO: check this if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { + // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA + if ((head_size <= 64 || head_size > 128) || !paged_KV) { + is_causal = false; + } + } if (is_causal) { window_size_right = 0; } - // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true. - // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM. 
- is_causal = window_size_left < 0 && window_size_right == 0; if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); @@ -583,15 +807,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (!paged_KV) { if (!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } @@ -608,8 +832,22 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_SHAPE(seqused_k, batch_size); } + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + // This is what we will template on + bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); + #ifdef FLASHATTENTION_DISABLE_VARLEN + TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); + #endif + int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; @@ -620,16 +858,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); } } else { - out = torch::empty_like(q, opts.dtype(out_type)); + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? 
head_size_rounded : round_up_headdimv(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -661,13 +902,17 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq softmax_scale, window_size_left, window_size_right, + attention_chunk, softcap, sm_margin); params.total_q = total_q; params.total_k = total_k; - params.sink_token_length = sink_token_length; params.b_k = batch_size_k; - + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; + if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma + params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); + } if (paged_KV) { params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); @@ -675,10 +920,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - - if (k_new_.has_value()) { + if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in"); @@ -702,10 +944,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); } else { CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); } params.seqlen_knew = seqlen_k_new; @@ -725,13 +967,80 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - params.leftpad_k = static_cast(leftpad_k.data_ptr()); + params.pagedkv_tma = get_pagedkv_tma(params); + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + + // This needs to be set after get_num_splits + at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic + // We don't use the persistent scheduler if Split and not Varlen + bool const scheduler_needs_semaphore = params.arch >= 90 + ? 
(((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) + : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); + params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); + if (scheduler_metadata_.has_value()) { + at::Tensor scheduler_metadata = scheduler_metadata_.value(); + CHECK_DEVICE(scheduler_metadata); + CHECK_SHAPE(scheduler_metadata, metadata_size); + CHECK_CONTIGUOUS(scheduler_metadata); + TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); + tile_count_semaphore = scheduler_metadata; + } else { + tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); + } + if (scheduler_needs_semaphore && !use_prepare_varlen) { + tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing + } + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later + } + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. 
+ params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } } if (rotary_cos_.has_value()) { @@ -756,6 +1065,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } } else { params.rotary_dim = 0; } @@ -772,12 +1088,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; @@ -790,18 +1106,6 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } - at::Tensor tile_count_semaphore; - // We don't use the persistent scheduler if Split and not Varlen - bool const persistent_scheduler = params.arch >= 90 - ? 
(((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) - : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (persistent_scheduler) { - tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32)); - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); - } else { - params.tile_count_semaphore = nullptr; - } - if (q_type == at::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); @@ -845,10 +1149,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA - TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table || params.num_splits > 1, "This flash attention build does not support pack_gqa."); + TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV - TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV."); + TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); @@ -864,12 +1168,18 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 // and seqlen = total_q, and don't need to dispatch to Varlen there. + // However, with dynamic split, each row needs to know which batch it belongs to + // to read the number of splits, so we just use the varlen version of combine kernel. // if (is_varlen_q && !seqused_q_.has_value()) { - if (is_varlen_q) { - params.b = 1; - params.seqlen_q = total_q; - } - run_mha_fwd_combine(params, stream); + // if (is_varlen_q) { + // params.b = 1; + // params.seqlen_q = total_q; + // } + // This will zero out the semaphore if needed + run_mha_fwd_combine(params, stream, true /*enable_pdl*/); + } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { + // need to zero out the semaphore in this case + tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
@@ -881,8 +1191,53 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq return {out, softmax_lse, out_accum, softmax_lse_accum}; } +#ifdef FLASHATTENTION_DISABLE_BACKWARD +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); +} +#else +template +void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { + if (!params.is_bf16) { + #ifndef FLASHATTENTION_DISABLE_FP16 + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + #else + TORCH_CHECK(false, "This flash attention build does not support FP16."); + #endif + } else { + #ifndef FLASHATTENTION_DISABLE_HDIM64 + if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM96 + if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM128 + if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM192 + if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } + #endif + #ifndef FLASHATTENTION_DISABLE_HDIM256 + if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } + #endif + } +} + void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { - #ifndef FLASHATTENTION_DISABLE_BACKWARD // FP16_SWITCH(!params.is_bf16, [&] { // HEADDIM_SWITCH(params.d, [&] { // run_mha_bwd_(params, stream); @@ -890,47 +1245,11 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // }); ARCH_SWITCH(params.arch, Arch, [&] { SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { - if (!params.is_bf16) { - #ifndef FLASHATTENTION_DISABLE_FP16 - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - #else - TORCH_CHECK(false, "This flash attention build does not support FP16."); - #endif - } else { - #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } - #endif - #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } - #endif - } + run_mha_bwd_constexpr(params, stream); 
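The dispatch refactor above moves head-dimension selection into run_mha_bwd_constexpr and switches from the cascading "params.d <= N" tests to an exact match on params.d_rounded. A small Python sketch of the equivalence, assuming for illustration only that the supported instantiations are {64, 96, 128, 192, 256} and that round_up_headdim rounds up to the next supported size (the real round_up_headdim in this codebase may differ per architecture):

# Hypothetical illustration of the old vs. new backward head-dim dispatch.
SUPPORTED_HDIMS = (64, 96, 128, 192, 256)  # assumed kernel instantiations

def round_up_headdim(d: int) -> int:
    # assumption: round up to the next supported head dim
    for h in SUPPORTED_HDIMS:
        if d <= h:
            return h
    raise ValueError(f"head dim {d} not supported")

def dispatch_old(d: int) -> int:
    # old style: first bucket such that d <= bucket
    for h in SUPPORTED_HDIMS:
        if d <= h:
            return h
    raise ValueError(f"head dim {d} not supported")

def dispatch_new(d_rounded: int) -> int:
    # new style: d_rounded must already be one of the supported sizes
    assert d_rounded in SUPPORTED_HDIMS
    return d_rounded

assert dispatch_old(80) == dispatch_new(round_up_headdim(80)) == 96

Both pick the same instantiation once the head dim has been rounded; the refactor just makes the kernel choice an explicit function of d_rounded, which mha_bwd below computes via round_up_headdim.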
}); }); - #endif } +#endif // b: batch_size @@ -939,30 +1258,30 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { // h: num_heads // h_k: num_heads_k // d: head_size -std::vector mha_bwd( - const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q - std::optional &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - std::optional &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k - std::optional &cu_seqlens_q_, // b+1 - std::optional &cu_seqlens_k_, // b+1 - std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - std::optional max_seqlen_q_, - std::optional max_seqlen_k_, - float const softmax_scale, +std::tuple mha_bwd( + at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k + std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k + std::optional cu_seqlens_q_, // b+1 + std::optional cu_seqlens_k_, // b+1 + std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, + std::optional max_seqlen_k_, + std::optional softmax_scale_, bool is_causal, - int window_size_left, - int window_size_right, - int const sink_token_length, - float const softcap, - bool const deterministic, - int const sm_margin) { + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool deterministic, + int64_t sm_margin +) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); @@ -1017,13 +1336,19 @@ std::vector mha_bwd( int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int const num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); int const max_headdim = get_max_headdim(); - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + double softmax_scale = 1.0 / sqrt(double(head_size)); + if (softmax_scale_.has_value()) { + softmax_scale = softmax_scale_.value(); + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } @@ -1034,7 +1359,8 @@ std::vector mha_bwd( is_causal = window_size_left < 0 && window_size_right == 0; int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; - int const head_size_rounded = round_up_headdim(head_size); + int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); + int const head_size_v_rounded = head_size_rounded; // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) @@ -1063,20 +1389,20 @@ std::vector mha_bwd( if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); - CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); + CHECK_SHAPE(dout, total_q, num_heads, head_size_v); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!is_varlen_k) { CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } @@ -1126,9 +1452,9 @@ std::vector mha_bwd( CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); if (!is_varlen_k) { - CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); } else { - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); } } else { dv = torch::empty_like(v); @@ -1158,10 +1484,10 @@ std::vector mha_bwd( if (num_heads_k != num_heads) { // MQA / GQA if (!is_varlen) { dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = 
torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat)); } else { dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat)); } } @@ -1187,13 +1513,15 @@ std::vector mha_bwd( softmax_scale, window_size_left, window_size_right, + 0, // attention_chunk softcap, deterministic, sm_margin); params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); - params.sink_token_length = sink_token_length; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); @@ -1201,9 +1529,9 @@ std::vector mha_bwd( at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); params.dq_semaphore = dq_semaphore.data_ptr(); if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } @@ -1231,9 +1559,9 @@ std::vector mha_bwd( return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; } -std::vector -mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size - const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads +std::tuple +mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size + at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads std::optional out_, // batch_size x seqlen x num_heads x head_size std::optional out_dtype_ ) { @@ -1258,7 +1586,6 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); @@ -1307,7 +1634,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.b = batch_size; params.h = num_heads; params.seqlen_q = seqlen; - params.d = head_size; + params.dv = head_size; params.num_splits = num_splits; params.oaccum_split_stride = out_partial_padded.stride(0); params.oaccum_row_stride = out_partial_padded.stride(2); @@ -1319,10 +1646,11 
@@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.o_row_stride = out.stride(1); params.o_head_stride = out.stride(2); params.o_batch_stride = out.stride(0); + params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; if (seqlen > 0 && batch_size > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd_combine(params, stream); + run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } at::Tensor out_padded = out; @@ -1334,9 +1662,100 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x return {out, softmax_lse}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "FlashAttention"; - m.def("fwd", &mha_fwd, "Forward pass"); - m.def("bwd", &mha_bwd, "Backward pass"); - m.def("fwd_combine", &mha_combine, "Combine partial attention outputs"); +TORCH_LIBRARY(flash_attn_3, m) { + m.def("fwd(" + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor(k_new!)? k_new = None," + "Tensor(v_new!)? v_new = None," + "Tensor? q_v = None," + "Tensor(out!)? out = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "Tensor? page_table = None," + "Tensor? kv_batch_idx = None," + "Tensor? leftpad_k = None," + "Tensor? rotary_cos = None," + "Tensor? rotary_sin = None," + "Tensor? seqlens_rotary = None," + "Tensor? q_descale = None," + "Tensor? k_descale = None," + "Tensor? v_descale = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "float softcap = 0.0," + "bool is_rotary_interleaved = False," + "Tensor? scheduler_metadata = None," + "int num_splits = 0," + "bool? pack_gqa = None," + "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); + m.def("bwd(" + "Tensor dout," + "Tensor q," + "Tensor k," + "Tensor v," + "Tensor out," + "Tensor softmax_lse," + "Tensor(dq!)? dq = None," + "Tensor(dk!)? dk = None," + "Tensor(dv!)? dv = None," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? seqused_q = None," + "Tensor? seqused_k = None," + "int? max_seqlen_q = None," + "int? max_seqlen_k = None," + "float? softmax_scale = None," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "float softcap = 0.0," + "bool deterministic = False," + "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("fwd_combine(" + "Tensor out_partial," + "Tensor lse_partial," + "Tensor(out!)? out = None," + "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); + m.def("get_scheduler_metadata(" + "int batch_size," + "int max_seqlen_q," + "int max_seqlen_k," + "int num_heads," + "int num_heads_k," + "int headdim," + "int headdim_v," + "ScalarType qkv_dtype," + "Tensor seqused_k," + "Tensor? cu_seqlens_q = None," + "Tensor? cu_seqlens_k = None," + "Tensor? cu_seqlens_k_new = None," + "Tensor? seqused_q = None," + "Tensor? leftpad_k = None," + "int? page_size = None," + "int max_seqlen_k_new = 0," + "bool is_causal = False," + "int window_size_left = -1," + "int window_size_right = -1," + "int attention_chunk = 0," + "bool has_softcap = False," + "int num_splits = 0," + "bool? 
pack_gqa = None," + "int sm_margin = 0) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { + m.impl("fwd", &mha_fwd); + m.impl("bwd", &mha_bwd); + m.impl("fwd_combine", &mha_combine); + m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata); } diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5f1e4899c9..a435e7a627 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -7,10 +7,11 @@ # isort: off # We need to import the CUDA kernels after importing torch -import flash_attn_3_cuda +import flash_attn_3._C # Registers operators with PyTorch # isort: on +flash_attn_3_cuda = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -22,6 +23,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -35,19 +37,21 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, rotary_interleaved=True, + scheduler_metadata=None, num_splits=1, pack_gqa=None, - sm_margin=0): - assert sink_token_length == 0, "sink_token_length not supported yet" + sm_margin=0, + ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ @@ -58,12 +62,14 @@ def _flash_attn_forward( maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + seqlens_rotary = maybe_contiguous(seqlens_rotary) out, softmax_lse, *rest = flash_attn_3_cuda.fwd( q, k, v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -77,6 +83,7 @@ def _flash_attn_forward( leftpad_k, rotary_cos, rotary_sin, + seqlens_rotary, q_descale, k_descale, v_descale, @@ -84,14 +91,15 @@ def _flash_attn_forward( causal, window_size[0], window_size[1], - sink_token_length, + attention_chunk, softcap, rotary_interleaved, + scheduler_metadata, num_splits, pack_gqa, sm_margin, ) - return (out, softmax_lse, *rest) + return out, softmax_lse, *rest def _flash_attn_backward( @@ -113,12 +121,10 @@ def _flash_attn_backward( softmax_scale, causal, window_size=(-1, -1), - sink_token_length=0, softcap=0.0, deterministic=False, sm_margin=0, ): - assert sink_token_length == 0, "sink_token_length not supported yet" # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( @@ -141,7 +147,6 @@ def _flash_attn_backward( causal, window_size[0], window_size[1], - sink_token_length, softcap, deterministic, sm_margin, @@ -158,10 +163,12 @@ def forward( causal, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, deterministic=False, num_heads_q=None, + sm_margin=0, + return_softmax=False, ): if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) @@ -174,30 +181,42 @@ def forward( num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert num_heads_k * 2 + num_heads_q == qkv.shape[2] q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) - out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( + out, softmax_lse, *rest = _flash_attn_forward( q, k, v, + None, None, # k_new, v_new + None, # qv + None, # 
out + None, None, None, # cu_seqlens_q/k/k_new + None, None, # seqused_q/k + None, None, # max_seqlen_q/k + None, None, None, # page_table, kv_batch_idx, leftpad_k, + None, None, None, # rotary_cos/sin, seqlens_rotary + q_descale, k_descale, v_descale, softmax_scale, causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, sink_token_length=sink_token_length, + window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, + sm_margin=sm_margin, ) - ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.ndim = qkv.dim() - # return out, softmax_lse - return out + ctx.sm_margin = sm_margin + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" if ctx.ndim == 5: qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -215,18 +234,21 @@ def backward(ctx, dout, *args): v, out, softmax_lse, + None, None, # cu_seqlens_q, cu_seqlens_k, + None, None, # sequed_q, sequed_k, + None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, + ctx.sm_margin, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @@ -239,34 +261,37 @@ def forward( v, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, + attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -277,15 +302,16 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not 
support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, @@ -303,15 +329,14 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -330,23 +355,26 @@ def forward( max_seqlen_k, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, + return_softmax=False, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out cu_seqlens_q, cu_seqlens_k, @@ -356,12 +384,12 @@ def forward( max_seqlen_q, max_seqlen_k, None, None, None, # page_table, kv_batch_idx, leftpad_k, - None, None, # rotary_cos/sin + None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, + attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, @@ -374,15 +402,16 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size - ctx.sink_token_length = sink_token_length + ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin - return out, softmax_lse + return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, @@ -403,15 +432,14 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, - ctx.sink_token_length, ctx.softcap, ctx.deterministic, ctx.sm_margin, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + dq = dq[..., : q.shape[-1]] # We could have padded the head dimension + dk = dk[..., : k.shape[-1]] + dv = dv[..., : v.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -420,10 +448,12 @@ def flash_attn_qkvpacked_func( causal=False, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, 
deterministic=False, num_heads_q=None, + sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than @@ -465,10 +495,12 @@ def flash_attn_qkvpacked_func( causal, q_descale, k_descale, v_descale, window_size, - sink_token_length, + attention_chunk, softcap, deterministic, num_heads_q, + sm_margin, + return_attn_probs, ) @@ -478,14 +510,16 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -538,14 +572,16 @@ def flash_attn_func( v, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, + attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, + return_attn_probs, ) @@ -555,20 +591,22 @@ def flash_attn_varlen_func( v, cu_seqlens_q, cu_seqlens_k, - seqused_q, - seqused_k, max_seqlen_q, max_seqlen_k, + seqused_q=None, + seqused_k=None, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), - sink_token_length=0, + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, + return_attn_probs=False, ): return FlashAttnVarlenFunc.apply( q, @@ -582,14 +620,16 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, - sink_token_length, + attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, + return_attn_probs, ) @@ -603,6 +643,7 @@ def flash_attn_with_kvcache( v_cache, k=None, v=None, + qv=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, @@ -612,15 +653,17 @@ def flash_attn_with_kvcache( cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window - sink_token_length=0, + attention_chunk=0, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, + scheduler_metadata=None, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication @@ -672,12 +715,13 @@ def flash_attn_with_kvcache( q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size can be arbitrary (e.g, 1, 2, 3, 64, etc.). + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. 
paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. @@ -710,14 +754,13 @@ def flash_attn_with_kvcache( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). """ - assert sink_token_length == 0 assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) out, softmax_lse, *rest = _flash_attn_forward( @@ -726,6 +769,7 @@ def flash_attn_with_kvcache( v_cache, k, v, + qv, None, # out cu_seqlens_q, None, # cu_seqlens_k @@ -739,16 +783,61 @@ def flash_attn_with_kvcache( cache_leftpad, rotary_cos, rotary_sin, + rotary_seqlens, q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size=window_size, - sink_token_length=sink_token_length, + attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out + + +def get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + attention_chunk=0, + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication +): + cache_seqlens = maybe_contiguous(cache_seqlens) + if headdim_v is None: + headdim_v = headdim + scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, + qkv_dtype, + cache_seqlens, + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_leftpad, + page_size, + max_seqlen_k_new, + causal, + window_size[0], window_size[1], + attention_chunk, + has_softcap, + num_splits, + pack_gqa, + sm_margin, + ) + return scheduler_metadata diff --git a/hopper/flash_bwd_kernel_sm80.h b/hopper/flash_bwd_kernel_sm80.h index b4fe26285c..aaec00dbe4 100644 --- a/hopper/flash_bwd_kernel_sm80.h +++ b/hopper/flash_bwd_kernel_sm80.h @@ 
-133,8 +133,8 @@ class FlashAttnBwdSm80 { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. @@ -155,15 +155,14 @@ class FlashAttnBwdSm80 { // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = collective_mainloop.mma( - params.mainloop, tdKrdK, tdVrdV, threadIdx.x, block_coord, - shared_storage); + bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, + block_coord, shared_storage); scheduler.prefetch_next_work(params.scheduler, work_tile_info); if (tile_valid) { - collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x, block_coord); + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x, block_coord); } else { - collective_epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_bwd_kernel_sm90.h b/hopper/flash_bwd_kernel_sm90.h index 7aa32a8460..b93a021916 100644 --- a/hopper/flash_bwd_kernel_sm90.h +++ b/hopper/flash_bwd_kernel_sm90.h @@ -195,8 +195,8 @@ class FlashAttnBwdSm90 { PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -206,6 +206,8 @@ class FlashAttnBwdSm90 { __syncthreads(); } + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); @@ -213,8 +215,6 @@ class FlashAttnBwdSm90 { if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state(); - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { @@ -224,32 +224,29 @@ class FlashAttnBwdSm90 { auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; - collective_mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, - smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); + mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, + smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); } - collective_mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); + mainloop.load_tail(pipeline_q, pipeline_do, 
smem_pipe_write, smem_pipe_write_do); } else if (warp_idx_in_warpgroup == 1) { - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; - collective_mainloop.store_dq(params.mainloop, shared_storage, block_coord); + mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. TiledMmadKV tiled_mma_dKV; PipelineState smem_pipe_read; PipelineState_dO smem_pipe_read_do; - collective_mainloop.mma_init(); + mainloop.mma_init(); scheduler.init_consumer(); int work_idx = 0; @@ -264,18 +261,18 @@ class FlashAttnBwdSm90 { // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); - bool tile_valid = collective_mainloop.mma( + bool tile_valid = mainloop.mma( params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); if (tile_valid) { - collective_epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, - threadIdx.x - NumCopyThreads, block_coord); + epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, + threadIdx.x - NumCopyThreads, block_coord); } else { - collective_epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); } } - collective_epilogue.store_tail(); + epilogue.store_tail(); } } diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index 635228eebc..b6e8810b25 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -49,7 +49,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_O + {seqlen_q, params.dv, params.h, batch_q}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO @@ -93,7 +93,11 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, flash::CollectiveEpilogueBwdGQA >; - using Scheduler = flash::SingleTileScheduler; + using Scheduler = std::conditional_t< + Is_causal && !Varlen, + flash::SingleTileBwdLPTScheduler, + flash::SingleTileScheduler + >; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90_or_later>, @@ -108,8 +112,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {seqlen_k, params.d, params.h_k, batch_k}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? 
params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V static_cast(params.do_ptr), + {seqlen_q, params.dv, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum @@ -120,7 +126,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.dsoftmax_sum), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum params.scale_softmax, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, 0 /*attention_chunk*/, params.softcap, params.b, params.dq_semaphore, @@ -145,11 +151,18 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + } + }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), params.h, @@ -165,7 +178,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { num_blocks_n, params.h, params.b, 1 /*num_splits*/, params.h / params.h_k, params.seqlen_k, - params.seqlen_q, params.d, sizeof(Element), + params.seqlen_q, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k }; @@ -256,10 +269,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? 
params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum static_cast(params.dv_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV 1.f, params.cu_seqlens_k, diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 5b7d9eed65..3e85a0a212 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -3,14 +3,11 @@ #include "flash_fwd_combine_launch_template.h" -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index aaec31e580..0566769800 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -12,6 +12,8 @@ #include #include +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" @@ -40,11 +42,11 @@ class FlashAttnFwdCombine { static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 
64 : 32); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemCopyAtom = std::conditional_t< @@ -98,8 +100,8 @@ class FlashAttnFwdCombine { Stride, _1>>{})); using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); - using SmemLayoutO = Layout, Int, Int>, - Stride, _1, Int>>; + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; // We want each column (kMaxSplits) to be processed by threads in the same warp. // To reduce the number of shuffles, we want as few threads on the same column as possible. @@ -128,38 +130,43 @@ class FlashAttnFwdCombine { static constexpr int SharedStorageSize = sizeof(SharedStorage); - // Device side arguments struct Arguments { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Kernel entry point API struct Params { - ElementPartial const* ptr_O_partial; + ElementPartial const* const ptr_O_partial; ShapeOPartial const shape_O_partial; StrideOPartial const stride_O_partial; - float const* ptr_LSE_partial; + float const* const ptr_LSE_partial; ShapeLSEPartial const shape_LSE_partial; StrideLSEPartial const stride_LSE_partial; - Element* ptr_O; + Element* const ptr_O; StrideO const stride_O; - float* ptr_LSE; + float* const ptr_LSE; StrideLSE const stride_LSE; cutlass::FastDivmod seqlen_divmod, head_divmod; - int const* cu_seqlens = nullptr; - int const* seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + int* const semaphore_to_reset = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -180,7 +187,11 @@ class FlashAttnFwdCombine { args.stride_LSE, cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)), args.cu_seqlens, - args.seqused + args.seqused, + args.num_splits_dynamic_ptr, + args.varlen_batch_idx_ptr, + args.semaphore_to_reset, + }; } @@ -195,17 +206,30 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; - int const batch = !Varlen ? 0 : blockIdx.y; - int const num_splits = get<1>(params.shape_LSE_partial); + int const k_block = blockIdx.y; + int const maybe_virtual_batch = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; + int const num_splits = params.num_splits_dynamic_ptr ? 
params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); + + if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { + cutlass::arch::wait_on_dependent_grids(); + *params.semaphore_to_reset = 0; + } + if (num_splits <= 1) { return; } flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; int const offset = seqlen_info.offset; int const seqlen = seqlen_info.seqlen; - int max_idx = seqlen * get<2>(params.shape_LSE_partial) * get<3>(params.shape_LSE_partial); + int max_idx = seqlen * get<2>(params.shape_LSE_partial); + if constexpr (Varlen) { + if (m_block * kBlockM >= max_idx) { return; } + } cutlass::FastDivmod seqlen_divmod_dynamic(seqlen); // Step 1: load LSE_partial from gmem -> smem - Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), select<1, 0, 2, 3>(params.shape_LSE_partial), select<1, 0, 2, 3>(params.stride_LSE_partial)); // (num_splits, seqlen, head, batch) + Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)), + select<1, 0, 2, 3>(params.shape_LSE_partial), + select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head) Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{}); GmemTiledCopyLSE gmem_tiled_copy_LSE; auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx); @@ -216,19 +240,20 @@ class FlashAttnFwdCombine { // Repeat the partitioning with identity layouts Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE); + cutlass::arch::wait_on_dependent_grids(); + #pragma unroll for (int m = 0; m < size<2>(tLSEcLSE); ++m) { int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m))); int idx = m_block * kBlockM + mi; if (idx < max_idx) { - int m_idx, bidh, bidb; + int m_idx, bidh; if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + bidh = params.seqlen_divmod.divmod(m_idx, idx); } else { bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; } - Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh, bidb); + Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh); #pragma unroll for (int s = 0; s < size<1>(tLSEcLSE); ++s) { int si = get<0>(tLSEcLSE(_0{}, s, _0{})); @@ -254,33 +279,32 @@ class FlashAttnFwdCombine { Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? 
batch : 0); // (seqlen, d, num_splits, head) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); Tensor tObidh = make_tensor(make_shape(size<1>(tOcO))); - Tensor tObidb = make_tensor(make_shape(size<1>(tOcO))); Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO))); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { int mi = get<0>(tOcO(_0{}, m, _0{})); int idx = m_block * kBlockM + mi; if constexpr (!Varlen) { - tObidb[m] = params.head_divmod.divmod(tObidh(m), params.seqlen_divmod.divmod(tOmidx(m), idx)); + tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx); } else { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); - tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m)); if (idx >= max_idx) { - tObidb[m] = -1; + tObidh[m] = -1; } } Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); if constexpr (!(Is_even_K)) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } } Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); @@ -289,8 +313,8 @@ class FlashAttnFwdCombine { Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { - Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}, _0{}).layout()); + if (tObidh(m) >= 0) { + Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout()); Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { @@ -358,26 +382,35 @@ class FlashAttnFwdCombine { // Store the scales exp(lse - lse_logsum) back to smem cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); - // Step 5: store final LSE back to gmem - auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + // Store max_valid_split to smem #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh, bidb; - if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? 
batch : 0); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh; + if constexpr (!Varlen) { + bidh = params.seqlen_divmod.divmod(m_idx, idx); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh) = lse_sum(m); } - // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); } - if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } } } @@ -408,7 +441,7 @@ class FlashAttnFwdCombine { #pragma unroll for (int m = 0; m < size<1>(tOrOpartial); ++m) { - if (tObidb(m) >= 0 && scale(m) > 0.f) { + if (tObidh(m) >= 0 && scale(m) > 0.f) { #pragma unroll for (int k = 0; k < size<2>(tOrOpartial); ++k) { if (Is_even_K || tOpO(k)) { @@ -427,20 +460,21 @@ class FlashAttnFwdCombine { // Step 7: Write the final O to gmem Tensor rO = make_tensor_like(tOrO); flash::convert_type_out(tOrO, rO); - auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial); - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O)(_, _, _, !Varlen ? 
batch : 0); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); #pragma unroll for (int m = 0; m < size<1>(tOcO); ++m) { - if (tObidb(m) >= 0) { + if (tObidh(m) >= 0) { #pragma unroll for (int k = 0; k < size<2>(tOcO); ++k) { int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad; if (Is_even_K || tOpO(k)) { - cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m), tObidb(m))); + cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m))); } } } diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 33e66c21f8..a2ff25dcd5 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -9,6 +9,7 @@ #include "cutlass/cutlass.h" #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80 #include "cutlass/device_kernel.h" // For device_kernel +#include "cutlass/kernel_launch.h" // For kernel_launch #include "static_switch.h" #include "flash.h" @@ -16,15 +17,16 @@ using namespace cute; -template -void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { - using TileShape_MK = cute::Shape, Int>; +template +void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { + using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; + using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; + IsEvenK, Varlen, Element, ElementPartial, ArchTag>; typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial static_cast(params.softmax_lseaccum_ptr), {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial @@ -33,42 +35,46 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); - int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, !Varlen ? 
1 : params.b); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); + int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM); + dim3 grid_m(num_blocks_m, num_blocks_k, params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } -template -void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { +template +void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32"); - static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32); - BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { - if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. - if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); - return; + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); + ARCH_SWITCH(params.arch, Arch, [&] { + BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] { + if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. + if (params.num_splits <= 16) { + run_flash_fwd_combine(params, stream, enable_pdl); + return; + } } - } - if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); - } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); - } else { - run_flash_fwd_combine(params, stream); - } + if (params.num_splits <= 32) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 64) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else if (params.num_splits <= 128) { + run_flash_fwd_combine(params, stream, enable_pdl); + } else { + run_flash_fwd_combine(params, stream, enable_pdl); + } + }); }); } diff --git a/hopper/flash_fwd_kernel_sm80.h b/hopper/flash_fwd_kernel_sm80.h index a2f550478a..b308d2d1b8 100644 --- a/hopper/flash_fwd_kernel_sm80.h +++ b/hopper/flash_fwd_kernel_sm80.h @@ -151,8 +151,8 @@ class FlashAttnFwdSm80 { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. 
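// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this diff): the dispatch heuristic that
// run_mha_fwd_combine_ above implements via template switches, written out as a
// plain host function. Names here are invented; the "max splits" buckets are
// inferred from the thresholds the code tests (<=16, <=32, <=64, <=128), and the
// size of the catch-all bucket is an assumption.
#include <cassert>

struct CombineDispatch {
    int kBlockM;          // rows of O/LSE handled per CTA
    int max_splits_pow2;  // smallest supported bucket >= params.num_splits
};

inline CombineDispatch pick_combine_dispatch(int kBlockK, int num_splits) {
    assert(kBlockK % 32 == 0 && "kBlockK must be a multiple of 32");
    // Keep kBlockM as small as possible to maximize parallelism, while every
    // thread still gets a full vectorized load along the head dimension
    // (e.g. hdim 64 -> kBlockM = 16 -> 256 threads, each reading 4 floats).
    int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    int max_splits_pow2;
    if (kBlockM >= 16 && num_splits <= 16) {   // kBlockM == 8 requires >= 32 splits
        max_splits_pow2 = 16;
    } else if (num_splits <= 32) {
        max_splits_pow2 = 32;
    } else if (num_splits <= 64) {
        max_splits_pow2 = 64;
    } else if (num_splits <= 128) {
        max_splits_pow2 = 128;
    } else {
        max_splits_pow2 = 256;                 // assumed largest instantiation
    }
    return {kBlockM, max_splits_pow2};
}
// The combine kernel itself now runs on a 3-D grid:
// (ceil_div(seqlen_q * h, kBlockM), ceil_div(dv, kBlockK), batch), so a wide V
// head dimension is tiled along gridDim.y rather than handled by a single CTA.
// ---------------------------------------------------------------------------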
@@ -187,25 +187,24 @@ class FlashAttnFwdSm80 { get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.store_kv_new( + bool tile_new_valid = mainloop.store_kv_new( params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord); if (tile_new_valid) { __syncthreads(); } } - bool tile_valid = collective_mainloop.mma( + bool tile_valid = mainloop.mma( params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord, shared_storage); scheduler.prefetch_next_work(params.scheduler, work_tile_info); if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, - threadIdx.x, block_coord); + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma, + threadIdx.x, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - collective_epilogue.template store_zero(params.epilogue, threadIdx.x, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index e5411042dc..47b3817cd2 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -14,6 +14,8 @@ #include #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/grid_dependency_control.h" + #include "seqlen.h" #include "utils.h" #include "softmax.h" @@ -35,22 +37,24 @@ class FlashAttnFwdSm90 { static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen); static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap; static constexpr bool Varlen = CollectiveMainloop::Varlen; - static constexpr bool PagedKV = CollectiveMainloop::PagedKV; static constexpr bool Split = CollectiveMainloop::Split; static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMma0 = typename CollectiveMainloop::TiledMma0; - using TiledMma1 = typename CollectiveMainloop::TiledMma1; + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; using ArchTag = 
typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -68,8 +72,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -88,7 +92,7 @@ class FlashAttnFwdSm90 { static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _1> { union { struct { cute::array padding_; @@ -98,9 +102,9 @@ class FlashAttnFwdSm90 { typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; @@ -176,7 +180,7 @@ class FlashAttnFwdSm90 { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; @@ -205,6 +209,9 @@ class FlashAttnFwdSm90 { if (warp_idx == 0 && lane_predicate) { shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); } @@ -216,12 +223,21 @@ class FlashAttnFwdSm90 { if constexpr (Use_TMA_KV) { pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_k.is_leader = warp_group_thread_idx == 0; - pipeline_params_k.num_consumers = NumMmaThreads; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; } else { - pipeline_params_k.consumer_arv_count = NumMmaThreads; + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? 
NumMmaThreads : cutlass::NumThreadsPerWarpGroup; pipeline_params_k.producer_arv_count = NumProducerThreads; } + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } + } + MainloopPipelineK pipeline_k = [&] { if constexpr (Use_TMA_KV) { return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); @@ -234,9 +250,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); } } else { PipelineParamsV pipeline_params_v; @@ -248,7 +264,6 @@ class FlashAttnFwdSm90 { return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } }(); - static_assert(is_same_v); // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then // the producer WG will read from pipeline_vt and write to pipeline_v. // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. @@ -256,11 +271,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. 
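// ---------------------------------------------------------------------------
// Illustrative aside (not part of this diff): why the K and V pipelines above end
// up with different consumer counts when LargeHeadDimV is set. In that mode only
// one MMA warpgroup runs the Q*K^T GEMM (so only it consumes K), while every MMA
// warpgroup participates in the P*V GEMM (so all of them consume V); see the
// mma / mma_pv split in the consumer loop further down. Roughly, a pipeline's
// consumer arrival count has to match the number of threads that release a stage:
constexpr int k_pipe_consumer_count(bool large_hdim_v, int num_mma_threads,
                                    int threads_per_warpgroup) {
    return large_hdim_v ? threads_per_warpgroup : num_mma_threads;
}
constexpr int v_pipe_consumer_count(int num_mma_threads) {
    return num_mma_threads;  // all MMA warpgroups consume V in either mode
}
// ---------------------------------------------------------------------------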
MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{}); + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); } else { - pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k); + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); } }(); @@ -272,10 +287,13 @@ class FlashAttnFwdSm90 { pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; pipeline_params_kv_new.num_consumers = NumMmaThreads; auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue; + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -285,18 +303,18 @@ class FlashAttnFwdSm90 { __syncthreads(); } + TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); + if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); // The pipelines for AppendKV and main attention are different, since e.g. main attention - // might use cp.async to load KV (if PagedKV) while AppendKV always uses TMA to load + // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load // KV_new. Since the pipeline states are different, we have to manually sync to make // sure the two pipelines don't race when accessing smem_k and smem_v. PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); int work_idx = 0; - - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; if constexpr (SingleProducerWarp) { @@ -304,6 +322,8 @@ class FlashAttnFwdSm90 { } if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } + cutlass::arch::wait_on_dependent_grids(); + // Load Q, K, V for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); @@ -313,13 +333,14 @@ class FlashAttnFwdSm90 { SeqlenInfo_t seqlen_info{ get<2>(block_coord) /*bidb*/, get<0>(params.mainloop.shape_Q), - !PagedKV ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.load_kv_new( + bool tile_new_valid = mainloop.load_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx); if (tile_new_valid) { @@ -332,16 +353,15 @@ class FlashAttnFwdSm90 { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; // pipeline_vt won't be used if we don't need to transpose V. - collective_mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, + mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); } - collective_mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); + mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); - TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. - TiledMma1 tiled_mma1; + TiledMmaPV tiled_mma_pv; PipelineState smem_pipe_read; PipelineState smem_pipe_read_new; @@ -349,38 +369,27 @@ class FlashAttnFwdSm90 { // (like in Cutlass's gemm) because the read and release pipeline states are always the same. scheduler.init_consumer(); - collective_mainloop.mma_init(); + mainloop.mma_init(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); - work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { - // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); - float softmax_scale_log2 = params.mainloop.softmax_scale_log2; - // If there's tanh softcap, the scaling will be done before tanh. + // get_next_work will be called before the epilogue + ) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); int const bidb = get<2>(block_coord); - if constexpr (Is_FP8 && !Has_softcap) { - int const bidh = get<1>(block_coord); - int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; - float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; - float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; - softmax_scale_log2 *= q_descale * k_descale; - } - flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); - SeqlenInfo_t seqlen_info{ bidb, get<0>(params.mainloop.shape_Q), - !PagedKV ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), + !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, + params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { - bool tile_new_valid = collective_mainloop.store_kv_new( + bool tile_new_valid = mainloop.store_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord); if (tile_new_valid) { @@ -395,22 +404,51 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } - bool tile_valid = collective_mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + // If there's tanh softcap, the scaling will be done before tanh. + float softmax_scale_log2 = params.mainloop.softmax_scale_log2; + if constexpr (Is_FP8 && !Has_softcap) { + int const bidh = get<1>(block_coord); + int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; + float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; + float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; + softmax_scale_log2 *= q_descale * k_descale; + } + flash::Softmax softmax(softmax_scale_log2); + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { // mma_pv might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { + tile_valid = mainloop.mma_pv( + params.mainloop, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); + } + } + // Do this here before the epilogue so that the next tile is ready to go. 
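// ---------------------------------------------------------------------------
// Illustrative aside (not part of this diff): the programmatic dependent launch
// (PDL) handshake that the Split && Varlen path below relies on, reduced to two
// toy kernels. The attention kernel plays the producer role (it signals once its
// last work tile has been claimed); the combine kernel plays the consumer role
// (it waits before reading the partial O/LSE). Both intrinsics come from
// cutlass/arch/grid_dependency_control.h, which this diff already includes; the
// kernel bodies here are placeholders.
#include "cutlass/arch/grid_dependency_control.h"

__global__ void pdl_producer_sketch(float* partial_out, bool is_last_tile) {
    // ... compute and store partial results into partial_out ...
    if (is_last_tile) {
        // Allow grids launched with the PDL attribute on the same stream to start.
        cutlass::arch::launch_dependent_grids();
    }
}

__global__ void pdl_consumer_sketch(const float* partial_out, float* out) {
    // Wait until the preceding grid's work (and its memory writes) is visible.
    cutlass::arch::wait_on_dependent_grids();
    // ... combine partial_out into out ...
}
// The consumer only overlaps if it is launched with the PDL launch attribute,
// which is what cutlass::kernel_launch(..., launch_with_pdl) does in the combine
// and forward launch code elsewhere in this diff (SM90 or newer only).
// ---------------------------------------------------------------------------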
+ work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); + if constexpr (Split && Varlen) { + if (!work_tile_info.is_valid(params.scheduler)) { // Last tile + cutlass::arch::launch_dependent_grids(); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, - threadIdx.x - MmaThreadOffset, block_coord); + epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, + threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. - // If Split, we don't have to write 0 to O if the mha_combine kernel is used, since it will - // not use the value of O if LSE is -inf. - collective_epilogue.template store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); - // collective_epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); + epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); } } - collective_epilogue.store_tail(); + epilogue.store_tail(); } } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 16701f160d..d48a4fd956 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -10,6 +10,7 @@ #include "cutlass/device_kernel.h" // For device_kernel #include #include "cutlass/cluster_launch.hpp" +#include "cutlass/kernel_launch.h" #include "static_switch.h" #include "flash.h" @@ -23,8 +24,8 @@ using namespace cute; -template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -35,28 +36,31 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? 
std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + static constexpr bool LPT = Is_causal || Is_local; + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -67,7 +71,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better // since we'll avoid launching a bunch of thread blocks that immediately exit. // On Sm80, noncausal persistent seems a bit slower. - using Scheduler = std::conditional_t= 90 ? (Split && !Varlen) : !((Is_causal && !Varlen) || (Varlen && Split)), SchedulerSingleTile, SchedulerPersistent>; + static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); + using Scheduler = std::conditional_t; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90_or_later>, @@ -89,16 +94,19 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {seqlen_q, params.d, params.h, batch_q}, // shape_Q {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q static_cast(params.k_ptr), - {!PagedKV ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, - params.d, params.h_k, !PagedKV ? batch_k : params.num_pages}, // shape_K + {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, + params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), + params.dv, // headdim_v v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new static_cast(params.vnew_ptr), {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? 
params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv static_cast(params.rotary_cos_ptr), {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter {params.rotary_dim / 2, _1{}}, // stride_rotary_cos @@ -107,31 +115,31 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { params.is_rotary_interleaved, params.page_table, // if page_size is not set, avoid dividing by zero - {params.kv_batch_idx ? params.b_k : params.b, !PagedKV ? 0 : params.seqlen_k / params.page_size}, // shape_page_table + {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table {params.page_table_batch_stride, _1{}}, // stride_page_table params.scale_softmax, params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, {params.q_descale_batch_stride, params.q_descale_head_stride}, {params.k_descale_batch_stride, params.k_descale_head_stride}, {params.v_descale_batch_stride, params.v_descale_head_stride}, - params.window_size_left, params.window_size_right, params.sink_token_length, + params.window_size_left, params.window_size_right, params.attention_chunk, params.softcap, params.num_splits, params.kv_batch_idx, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, - params.leftpad_k, + params.leftpad_k, params.seqlens_rotary }; typename CollectiveEpilogue::Arguments epilogue_args { - static_cast(!Split ? params.o_ptr : params.oaccum_ptr), - {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O - {!Split ? params.o_row_stride : params.oaccum_row_stride, - _1{}, - !Split ? params.o_head_stride : params.oaccum_head_stride, - !is_varlen_q ? (!Split ? params.o_batch_stride : params.oaccum_batch_stride) : 0, - !Split ? 0 : params.oaccum_split_stride}, // stride_O - static_cast(!Split ? params.softmax_lse_ptr : params.softmax_lseaccum_ptr), - {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, !Split ? 0 : params.h * seqlen_q * batch_q}, // stride_LSE + static_cast(params.o_ptr), + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O + static_cast(params.oaccum_ptr), + {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial + static_cast(params.softmax_lse_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE + static_cast(params.softmax_lseaccum_ptr), + {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial params.h_k, params.cu_seqlens_q, params.seqused_q }; @@ -143,10 +151,19 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? 
params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, sizeof(Element), - params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q + params.seqlen_k, params.d, params.dv, sizeof(Element), + params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, + params.num_splits_dynamic_ptr, + params.num_m_blocks_ptr, + params.varlen_batch_idx_ptr, + params.num_nheads_in_l2_ptr }; + if (Varlen && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); + CHECK_CUDA_KERNEL_LAUNCH(); + } + int device; CHECK_CUDA(cudaGetDevice(&device)); typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ @@ -174,29 +191,33 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } - kernel<<>>(kernel_params); + // kernel<<>>(kernel_params); + cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, + Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; - using T_out = std::conditional_t, float>; + using T_out = std::conditional_t; CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; - - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 
1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu new file mode 100644 index 0000000000..1d810c015e --- /dev/null +++ b/hopper/flash_prepare_scheduler.cu @@ -0,0 +1,250 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#include +#include "cutlass/fast_math.h" +#include "cutlass/barrier.h" +#include "cutlass/arch/barrier.h" + +#include "cutlass/arch/grid_dependency_control.h" + +#include "flash.h" + +#include "static_switch.h" + +namespace flash { + +// Sort in descending order +template +struct PrepareSortOp +{ + __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) + { + return lhs > rhs; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template +__global__ void prepare_varlen_num_blocks_kernel( + int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, + int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, + int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr, + int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, + cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, + int* const tile_count_semaphore, + int* const num_m_blocks_ptr, + int* const num_splits_dynamic_ptr, + int* const varlen_batch_idx_ptr, + // int* const num_n_blocks_ptr, + int* const num_nheads_in_l2_ptr, + bool enable_pdl, + bool is_causal, + bool packgqa, + int max_kvblocks_in_l2) { + + static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; + static constexpr int kSmemSize = 1; + static constexpr int BLOCK_DIM_X = NumWarps * 32; + static constexpr int ITEMS_PER_THREAD = 1; + static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); + using BlockMergeSort = cub::BlockMergeSort; + + __shared__ int total_blocks_smem[kSmemSize]; + + // Allocate shared memory for BlockMergeSort operations + __shared__ typename BlockMergeSort::TempStorage temp_storage; + + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } + + if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } + __syncthreads(); + + if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } + + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + + auto get_num_m_blocks = [&](int batch_idx) { + int seqlen; + if (seqused_q) { + seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; + } else if (cu_seqlens_q) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_q_static; + } + if(packgqa) { seqlen *= qhead_per_khead; } + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? 
blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; + }; + + auto get_num_n_blocks = [&](int batch_idx) { + int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; + int seqlen; + if (seqused_k) { + seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0; + } else if (cu_seqlens_k) { + int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = seqlen_k_static; + } + int seqlen_new; + if (cu_seqlens_k_new) { + int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; + int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); + seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; + } else { + seqlen_new = seqlen_k_new_static; + } + // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } + seqlen = seqlen - leftpad_k + seqlen_new; + return batch_idx < num_batch && lane < kNumBatchPerWarp + ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; + }; + + int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; + int batch_cta_idx_offset = int(blockIdx.x) * 992; + int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; + int batch_idx = lane + bidb_start; + int num_m_blocks = get_num_m_blocks(batch_idx); + int num_n_blocks = get_num_n_blocks(batch_idx); + + auto get_nheads_in_l2 = [&](int n_blocks) { + int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 + : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 + : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 + : n_blocks * 2 <= max_kvblocks_in_l2 ? 2 + : 1; + if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } + return min(nheads_in_l2, num_head); + }; + + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; + } else { + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); + } + + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = 
%d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } + + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + batch_idx = batch_cta_idx_offset + threadIdx.x; + if (batch_idx < num_batch && threadIdx.x < 992) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; + varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } + } + +} + +} // flash + +void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, + int blockM, int blockN, bool enable_pdl) { + int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); + int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 + int num_ctas = cutlass::ceil_div(params.b, 31 * 32); + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice + int const element_size = params.is_e4m3 ? 1 : 2; + int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; + // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); + int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; + BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { + NUM_WARP_SWITCH(num_warps, NumWarps, [&] { + flash::prepare_varlen_num_blocks_kernel<<>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? 
params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.varlen_batch_idx_ptr, + // params.num_n_blocks_ptr, + params.num_nheads_in_l2_ptr, + enable_pdl, + params.is_causal, + packgqa, + max_kvblocks_in_l2); + }); + }); +} diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index e741c13826..b91a5b128f 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -38,7 +38,7 @@ KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif """ @@ -46,8 +46,8 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif """ @@ -85,6 +85,7 @@ class Kernel: sm: int dtype: str head_dim: int + head_dim_v: int split: bool paged_kv: bool softcap: bool @@ -98,14 +99,15 @@ def template(self) -> str: # Always enable PackGQA for PagedKV or Split to reduce compilation packgqa = self.packgqa or self.paged_kv or self.split return KERNEL_IMPL_TEMPLATE_FWD_SM90.format( - ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower() ) else: # Always enable PackGQA for Sm8x to reduce compilation return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower() ) @@ -117,13 +119,13 @@ def template(self) -> str: ) else: return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SOFTCAP=str(self.softcap).lower() ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" def 
get_all_kernels() -> List[Kernel]: @@ -133,20 +135,32 @@ def get_all_kernels() -> List[Kernel]: if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))): continue if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 192: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=256, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): if sm < 90: continue - kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm] + # Same hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v] if len(kernels) > 0: filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) yield KERNEL_BATCH(template, filename) + # Different hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 8e7b4a314d..031ea44a0b 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -22,9 +22,20 @@ inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, in // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. 
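// ---------------------------------------------------------------------------
// Illustrative aside (not part of this diff): the wave-quantization search that
// the comment above describes, written as a standalone function. It mirrors the
// loop in num_splits_heuristic below (efficiency = n_waves / ceil(n_waves)) and
// then picks the smallest split count reaching 85% of the best efficiency. The
// new early branch (splitting when a single KV head exceeds the assumed 50 MB of
// L2 and the attention is neither causal nor local) is left out for brevity.
#include <algorithm>
#include <cmath>

inline int num_splits_sketch(int total_mblocks, int num_SMs, int num_n_blocks,
                             int max_splits) {
    if (total_mblocks >= 0.8f * num_SMs) { return 1; }  // SMs are nearly full already
    if (num_n_blocks <= 4) { return 1; }                // too little KV work to split
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float best_eff = 0.f;
    // Pass 1: best achievable wave efficiency over all candidate split counts.
    for (int s = 1; s <= max_splits; ++s) {
        float n_waves = float(total_mblocks) * s / num_SMs;
        best_eff = std::max(best_eff, n_waves / std::ceil(n_waves));
    }
    // Pass 2: smallest split count whose efficiency is within 85% of the best.
    for (int s = 1; s <= max_splits; ++s) {
        float n_waves = float(total_mblocks) * s / num_SMs;
        if (n_waves / std::ceil(n_waves) >= 0.85f * best_eff) { return s; }
    }
    return 1;
}
// ---------------------------------------------------------------------------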
-inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { +inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split - if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + // However, in the case of super long seqlen where each head of KV doesn't even fit into + // L2 (we assume that L2 size is 50MB), we want to split. + if (total_mblocks >= 0.8f * num_SMs) { + int const size_l2 = 50 * 1024 * 1024; + // Only split if there are enough queries to go over the KV at least twice + // Don't split if causal + if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) { + return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits); + } else { + return 1; + } + } // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if (num_n_blocks <= 4) { return 1; } max_splits = std::min({max_splits, num_SMs, num_n_blocks}); @@ -32,7 +43,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n std::vector efficiency; efficiency.reserve(max_splits); for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float n_waves = float(total_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu index 18879eff6e..affc7a4dd9 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu index 35c0ad78fd..7e13614bfe 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu index 7a39869a00..670041341b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ 
#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu index fb7ba5caed..f315fbb454 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu index 296ec9e91f..bde3024a4a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu index 8cffb6de83..2724463e62 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu index 12d564ce36..a38a1d5cf3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, 
true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu index 845b1fa5d0..284eeba182 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu index 25fbfda38d..0c40ddba8f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu new file mode 100644 index 0000000000..4fb8f71d01 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM128 +template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu index 1130ca747d..cc89c4d5d2 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu index 502bc1d177..3a236b712c 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu index 537e42ba56..8449104c5a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu index 2255e7949e..b152b90bab 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu 
b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu index 086f55b358..8cc4fed173 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu index 54590eebbc..1db3f1e6d8 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu index af322d1d15..9b3e294f1b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu index 3e83398e7f..07bd687fc3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu index 3f917d26ab..5f44833b10 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef 
FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu index 87c78f2892..9f95ca29f6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu index e56b64c3d9..ad97737d4f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu index 8202bfadde..d77d37ec04 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu index ee7439b277..ae05c7ce5f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu index 812239ef50..bc52a9f356 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, 
true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu index 74e52315bd..480d485d06 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu index fe0bff6a1d..d3da5f4e66 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu index 55df1a6663..1c1c2d8207 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu index 03a9c61e40..371d933e3e 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu index 67ba153c60..7491148dcd 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, 
cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu index 9f7bcec9ed..d04159a62a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu index 7116702f3f..28ad6c1496 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu index 04f18ac0fa..7afb267e3e 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu index c7c7c9e69f..69758584cb 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu index b4ea8bc330..3be45956bb 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include 
"flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu index ec99965c92..698095dad6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu index d1dd964523..16d443a9ad 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu index 83274ca3fd..1e8f6af71b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu index 80e9eb0e2c..4ec6888611 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, 
true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu index fbbc273b7e..670b5952d9 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu index f4f4829f33..b9778dc92e 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu index c768a89fdf..446e917c79 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu index 89c2db39e6..fd62a2c543 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu index 5b87286aef..0a397f4acf 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" 
#ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu index 7506097821..4d3c553e29 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu index d3b7b0f87b..77621846ff 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu index 4d8625cd6d..7d217ac273 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu index f6f129c550..0b6430abc2 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git 
a/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000..ea1e266f8d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu new file mode 100644 index 0000000000..2d7488fefe --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..8718571e30 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000..f7dfc18fc1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..935f5a0fe6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 0000000000..3f4d858ff5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..54d720efeb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu new file mode 100644 index 0000000000..b9b93af4fc --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu new file mode 100644 index 0000000000..39d9167b9f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000..0f86458012 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000..bd6f4df8f6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu new file mode 100644 index 0000000000..1824b86c64 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000..87dd01725a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000..6594d56012 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..d7dc84ebc1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu new file mode 100644 index 0000000000..b9d6e54cbe --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..a8c47652ec --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000..32d17c7665 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu new file mode 100644 index 0000000000..365017c256 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000..82cfdf040b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000..f3254936a4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu new file mode 100644 index 0000000000..931a6dbf86 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..5c8877a756 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000..1e230ab084 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..03716c8623 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu new file mode 100644 index 0000000000..54c66c9552 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..e5e0ec47db --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu new file mode 100644 index 0000000000..e4411b5db3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu new file mode 100644 index 0000000000..157ed06ddd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000..7ef5adc9e8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu index 96243edf0a..bf8386b829 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu index a51a894588..cbc6f98842 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu index 
515d88a11a..d5aa15b5c8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu index e5a154c18d..b8593612df 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu index 2bd860c775..a03514d919 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu index 6e1d803781..df547749e9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu index 942685e148..1ddb191620 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void 
run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu index d6050520e0..cefffcd216 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu index 7ee500a80e..3d4333b9e1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu index 1f9d8bfd56..35a2abef8c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu index 0313ad1b2c..99e34ac0bf 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git 
a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu index 8d87eb21f2..ed1cf22d5c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu index 081bb31b12..4527d9a279 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu index a9b5aa0de4..41fcf80017 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu index d465545ef8..704cbcb337 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu index 68c5714553..e0ea082156 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include 
"flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu index e1d656e5ae..a9c00408a8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu index 57d1c73d85..1497e7aa84 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu index 5104d43981..c66ea9baca 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu index cbc61f27e7..a7e472b478 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu index f08ba1459e..9f090aeeda 100644 --- 
a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu index e413758de8..2205168a67 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu index c8205c1605..2a01898b56 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu index f0db959e0f..888e241a9f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu index 249cae97ff..2a6bde7a39 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu index 14b073deb3..3d315187b2 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu +++ 
b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu index 8152dbaa6b..3c3d093803 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu index d0b0df0279..4ca103566d 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu index 24f3e128dc..16debf2779 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu index 6eabe0ee26..43c2615718 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu index 5c780da81f..d9d483838f 100644 --- 
a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu index 5a94366017..70543998d9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu index 9815dd1355..c30c7e3b8b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu index 66fc2cb8a6..7ae26e69c9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu index 2ceddd8cae..155b5a539f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, 
cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu index 4c64bc61c5..3e6173c31c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu index 6ad1a1529c..e1e3191a20 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu index f0ee8c0159..8272ecb76c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu index 4a9583196b..74606c3937 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu index 2b65a88f06..89a58502b3 100644 --- 
a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu index e324a93267..b13373806a 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu index a8be65709d..1335fad7f2 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu index 1ad82d7edf..18c31bdfc0 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu index 75f53ee4f6..18a5603cf1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, 
cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu index 09f7652633..4e99c7db02 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu index e5299154c3..82f8204aa6 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu index 364579e1b3..cb851a7711 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu index a5f821becd..ae2871c165 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu index 364bd2b3ae..ed24fbffef 100644 --- 
a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu index 3d2e337e16..ffca9c7f8f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu index 310c4a5c38..57a06bd6e6 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu index 96f5bbf3ad..ccdcf21e49 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu index 7d3131bd5b..c2bc778776 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 
true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu index 7715a52531..6bba953fc6 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu index 686bdfa5c7..25c96174c7 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu index 97fdc0094c..f172239e5b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu index 25a90d3be3..9dde6adb04 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu 
b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu index 4c91ee5bc0..2317adef8c 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu index ef12a584c6..b9b3b74867 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu index e4e746f9d3..c57a5a30ab 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu index 99924af52c..4f59a6aea9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu index 705582b9f2..2c2de1574a 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template 
void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu index 7e96901205..0dbd062c79 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu index 058eca375d..bee54c702d 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu index 679066d544..c02e683349 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu index e4ce6f9aa1..02b50b98b8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu index 03eff4c6f7..6599de63bb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu +++ 
b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu index 26df5e592e..a1cdc775cb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu index 57de7421df..6d01be60f5 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu index 53974f3e61..968bbf36f8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu index 24e1f63563..d564a62211 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu index a2fc325dad..cb5bccc176 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu
index 2c1f5f56f1..146a7bc343 100644
--- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu
index 7cbdff3e8a..a195e0931c 100644
--- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu
index b81bf0b99b..045fc71bed 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu
index 88a00e9121..a31da2eddf 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu
index c28edfd8f9..7382b58a23 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu
index dbcd163308..87ca31ce90 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu
index 63620ec90a..60f4d6ebbf 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu
index d8c11ee6a0..e0d5d318bc 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu
index 4af31d0bf9..dec7db046b 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu
index c7a04dc47b..7b71f43522 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu
index 9bca3a1c5e..08fc989af8 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu
index acd0fa660f..2cc8b5b86d 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu
index a38430fb3c..644e268469 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu
index 03bb0516f7..1ebcec8b3f 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu
index 8ea90bd417..780ade7f6b 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu
index f914432642..bfcffe2a39 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu
index e7e1cecd1f..ba4ba78ad4 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu
index 18b79da92c..f04260ba4f 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu
index 1c1c9470d6..33c78e5305 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu
@@ -6,7 +6,7 @@
 
 #ifndef FLASHATTENTION_DISABLE_SM8x
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu
index 6cadc2641d..8388420921 100644
--- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu
+++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu
@@ -5,5 +5,5 @@
 #include "flash_fwd_launch_template.h"
 
 #ifndef FLASHATTENTION_DISABLE_HDIM256
-template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
 #endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu
new file mode 100644
index 0000000000..8d037153cb
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+#ifndef FLASHATTENTION_DISABLE_HDIM64
+template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
+#endif
diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu
new file mode 100644
index 0000000000..c62e0b8d82
--- /dev/null
+++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+// Splitting the different template instantiations to different files to speed up compilation.
+// This file is auto-generated.
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..5e22d67f70 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000..1e005b3f01 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..96c4f55afd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu new file mode 100644 index 0000000000..8a92fe291e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..f47cb32667 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu new file mode 100644 index 0000000000..1915feb046 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu new file mode 100644 index 0000000000..fbc1577661 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000..88445691ff --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000..f7d051a34d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu new file mode 100644 index 0000000000..c83c1741d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..2e06c89a8c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000..46479ec15e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..18681ec42b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu new file mode 100644 index 0000000000..d2245aa136 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..022cdd3957 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu new file mode 100644 index 0000000000..67a324d52e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu new file mode 100644 index 0000000000..664f88dbfc --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000..6bd6b9ab38 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000..2f4ceaaed5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu new file mode 100644 index 0000000000..5fd59af348 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..e0f885b0f7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000..6dcda01962 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..5d20be6d2a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu new file mode 100644 index 0000000000..47463a7151 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..622b5533ce --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu new file mode 100644 index 0000000000..c83f44722c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu new file mode 100644 index 0000000000..5c9130f864 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000..a152022cb6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000..ef05aa2038 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu new file mode 100644 index 0000000000..19fe6d94f7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..6eb2d3d134 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000..ffbc998212 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..3d35075b48 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu new file mode 100644 index 0000000000..c2af33cf53 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..e07547c92d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu new file mode 100644 index 0000000000..1a04eb01f5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu new file mode 100644 index 0000000000..da9afc1157 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000..5e63a15515 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu index 4b650f53cf..4134d7d80b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu index 29cb3fe18b..11e3503b0d 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu index 2612bc9c98..67e39bd737 100644 --- 
a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu index 4c5fae060a..c37844daa5 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu index c0b58521bc..f0c40e2f89 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu index 0a05884724..3ed9694908 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu index b421199714..4a16aae66c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params 
¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu index 7f337595b5..b5b5fc26b2 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu index c4c35a18c7..3b29be627e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu index 9ea549e117..5f1c298c4c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu index 8ffc852e3e..64895643d2 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu index 7143da2f79..dd508590d6 100644 --- 
a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu index 4f7cd4f8e4..8411b6fccb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu index 5a9bb14205..b5b4f40770 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu index dc9b71a5b3..e608da04b0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu index 4c5440436a..c69b78ac3b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, 
cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu index d988a48f99..170cdb5cb8 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu index c6ae246e7f..ef0d1e921c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu index 761a625564..6a7fc29ddd 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu index a74d7c2c3b..faeb6c487f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu index 6d48fb099b..655258d519 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, true, 
true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu index 0e49f26aaa..4bd8ad8f26 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu index f780a8eb73..657820f285 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu index 948c8b17c8..cb0955d1a5 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu index 519783851f..357b64e83b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu index d5392ef3b0..c120792586 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, 
cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu index 06086d4084..21687f8932 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu index a15ab4c60d..4df8ed64d7 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu index 7038c0ad72..b601195d7e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu index 9a805fd3e5..ced4753189 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu index b23cb43e77..03090f73cb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params,
cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu index c18f470fcf..d6fe1559ca 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu index d61b04a07e..7b5ae4a56a 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu index 1d33fe12e0..6c603b4dca 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu index 03ac4d2f84..26d25fc190 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu index 7b031a4903..05a0baf18b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu +++
b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu index 77dbc58123..3a45776537 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 6bae5faa53..9b80bae51f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu index 30f666a73f..f6810efafb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu index 358e813eca..98c018893f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu index f5df3f502b..a10dfaca72 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu index f16185c3ac..b912a81443 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu index 796e4d63a3..8603c396e1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu index 6eeb977415..dc55dbc66a 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu index aa1d2cd05f..ef48844972 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void
run_mha_fwd_<80, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu index 5a92ebdddf..b1c0ead6e5 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu index 78c390e5ef..5d76d0fff0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu index 2b5aaff0d8..44ea823d27 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu index f0fa3ac63d..30fe623508 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu index
0d9407b2ce..6eb12dc80a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu index 223b6783e0..b806fc9d50 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu index 2f49d5f5aa..8f0a26da03 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu index 9661156d88..6de2819a17 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu index b5f6d7f875..16927295b8 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void
run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu index 82b827e180..0841307209 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu @@ -5,5 -5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 042dd0cc71..7d4dcdc293 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu index 4712aed6c3..b4dfbf7f8b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu index 8295033dee..1fa048752d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu
b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu index 21c43e6dbd..e0b6a75e63 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu index d3317ad628..e257b42f79 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu index 86218988c2..f97ab4733a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu index 7a6450373c..cee43ef94c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu index 34c1a3d3f0..0442e1f94b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96,
true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu index 96affd254c..bc71fa9e71 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu index 489717ff2c..b61dd71885 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu index 69917aa1ea..f47e1f5cda 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu index 3e3cc66f66..215752f1b0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu index e5f53e49c5..207afc7924 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef
FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu index 0899aa8987..6c38c08338 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu index 22f4cf6b14..dc2eb35dc2 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu index d601d694d4..f04e8bca6f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu index 1c5ba9b006..2697f6910e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu index 8073b677a1..e7a98b2e6e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void
run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu index 857be35920..98fb39c86e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu index 6931ffa279..cb938ad93b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu index 84facb47e7..e2dc45c79c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu index 878d160ff2..64f99c05a3 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu index e5561f7d63..3fdbbf23ba 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
-template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu index 30474d3543..ffe202ee39 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu index 074f7232f2..42740f0228 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu index 734abb7b0e..829929980d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu index 285e7ef520..d6a330432a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params
&params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu index d552e45db1..39c774e6f7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 64ca02345e..bc54be11e6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu index 3d8bb7c277..a68790500d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu index 6fab8802c5..3bca3065c7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu index 1fb30696dd..985692b9fa 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params &params,
cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu index af9b88d9a3..3c99cb6b5a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu index 5f9794a987..cf77a1ae81 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu index c906649acc..f9a46a44dd 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu index 2d7ac26e25..9b4dbbba58 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu
b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu index 171f28e9ce..da5373fd13 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu new file mode 100644 index 0000000000..ddd8bf07c4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu new file mode 100644 index 0000000000..c9494c4f1d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..4b2ec583cf --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu new file mode 100644 index 0000000000..306722d458 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated.
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..e44b2d2465 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu new file mode 100644 index 0000000000..d52417daef --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..6428c461aa --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu new file mode 100644 index 0000000000..d0df6306e2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu new file mode 100644 index 0000000000..e116d3ea7c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_split_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu new file mode 100644 index 0000000000..bededf4a7d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu new file mode 100644 index 0000000000..526a51fb71 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu new file mode 100644 index 0000000000..4e5d9cc4fe --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu new file mode 100644 index 0000000000..f553af139f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu new file mode 100644 index 0000000000..aa2a8260d2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..bbc4449ba2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu new file mode 100644 index 0000000000..02ca85ad67 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..d090fde972 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu new file mode 100644 index 0000000000..d48f60ad7e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu new file mode 100644 index 0000000000..9dda19d1ce --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu new file mode 100644 index 0000000000..f3e51fc9eb --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu new file mode 100644 index 0000000000..ea53102793 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu new file mode 100644 index 0000000000..10d86e5e99 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu new file mode 100644 index 0000000000..375197ef75 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu new file mode 100644 index 0000000000..4fc4831cf5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 0000000000..a3d94a163a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu new file mode 100644 index 0000000000..9663103ae1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 0000000000..b7d2b07ca8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu new file mode 100644 index 0000000000..471b5abaaf --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu new file mode 100644 index 0000000000..10f72182fa --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_split_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu new file mode 100644 index 0000000000..54db60c23b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/mainloop_bwd_sm80.hpp b/hopper/mainloop_bwd_sm80.hpp index e7b3d2dead..23baae6173 100644 --- a/hopper/mainloop_bwd_sm80.hpp +++ b/hopper/mainloop_bwd_sm80.hpp @@ -13,6 +13,7 @@ #include "seqlen.h" #include "mask.h" +#include "mask.h" #include "softmax.h" #include "utils.h" @@ -38,7 +39,6 @@ struct CollectiveMainloopBwdSm80 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - using SeqlenInfo_t = flash::SeqlenInfoQK(TileShape_MNK{}))>; static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; static constexpr bool SdP_swapAB = SdP_swapAB_; @@ -51,6 +51,9 @@ struct CollectiveMainloopBwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + static_assert(ArchTag::kMinComputeCapability >= 80); static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; @@ -281,8 +284,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -293,7 +298,7 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -312,8 +317,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -325,7 +332,8 @@ struct CollectiveMainloopBwdSm80 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int *const dq_semaphore; @@ -338,6 +346,9 @@ struct CollectiveMainloopBwdSm80 { static Params to_underlying_arguments(Arguments const& args) { if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -349,39 +360,19 @@ struct CollectiveMainloopBwdSm80 { // (the original softmax_scale) at the end. 
return {args.ptr_Q, args.shape_Q, args.stride_Q, args.ptr_K, args.shape_K, args.stride_K, - args.ptr_V, args.stride_V, - args.ptr_dO, args.stride_dO, + args.ptr_V, args.shape_V, args.stride_V, + args.ptr_dO, args.shape_dO, args.stride_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; } - CUTLASS_DEVICE - cute::tuple get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int n_block, int bidb) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); - } - } - int m_block_min = 0; - if constexpr (Is_causal || Is_local) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); - } - return {m_block_min, m_block_max}; - } - template CUTLASS_DEVICE bool mma(Params const& params, @@ -400,7 +391,9 @@ struct CollectiveMainloopBwdSm80 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto m_block_min_max = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto m_block_min_max = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); int const m_block_min = get<0>(m_block_min_max); int const m_block_max = get<1>(m_block_min_max); // It's possible to have m_block_max <= m_block_min. Exit early @@ -428,9 +421,9 @@ struct CollectiveMainloopBwdSm80 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? 
bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), @@ -542,13 +535,16 @@ struct CollectiveMainloopBwdSm80 { for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOsdO))); + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); } int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, - params.qhead_per_khead_divmod + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); { @@ -560,9 +556,12 @@ struct CollectiveMainloopBwdSm80 { Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + #pragma unroll + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); } // Do we need bound check to make sure the row doesn't go above kBlockN static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; // static_assert(EvenN); // It simplifies the loading of K and V @@ -582,7 +581,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); } } } @@ -595,7 +594,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tKsK); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); } } } @@ -668,7 +667,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; #pragma unroll for (int k = 0; k < size<2>(tdOsdO); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); } } } @@ -832,21 +831,21 @@ struct CollectiveMainloopBwdSm80 { // if (cute::thread0()) { print_tensor(tdVrdV); } __syncthreads(); // make sure sdS is written auto do_mma_dQ = [&] (auto hook) { - Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, 
select(TileShape_MNK{})); - clear(tdQrdQ); - Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); - Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); - flash::gemm_sm80( - tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, - // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); - // if (cute::thread0()) { print_tensor(tdQrdQ); } - // We can reuse r2s_thr_copy_dQaccum for this partitioning - Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); - Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); - static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); - #pragma unroll - for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); + clear(tdQrdQ); + Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS); + Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt); + flash::gemm_sm80( + tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ, + // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next); + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook); + // if (cute::thread0()) { print_tensor(tdQrdQ); } + // We can reuse r2s_thr_copy_dQaccum for this partitioning + Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); + Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block); + static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } }; // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); } @@ -861,7 +860,7 @@ struct CollectiveMainloopBwdSm80 { tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next)); } - if constexpr (kStages == 1) { + if constexpr (kStages == 1) { __syncthreads(); do_mma_dQ(load_Q_next); } diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 393a6e5814..ec34e20eca 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -17,6 +17,7 @@ #include "named_barrier.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "softmax.h" #include "utils.h" @@ -48,7 +49,6 @@ struct CollectiveMainloopBwdSm90 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - using SeqlenInfo_t = flash::SeqlenInfoQK(TileShape_MNK{}))>; static constexpr bool SdP_swapAB = SdP_swapAB_; static constexpr bool dKV_swapAB = dKV_swapAB_; @@ -60,6 +60,9 @@ struct CollectiveMainloopBwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQK; + using BlockMN_t = flash::BlockMN; + static_assert(ArchTag::kMinComputeCapability >= 90); static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); @@ -295,8 +298,10 @@ struct CollectiveMainloopBwdSm90 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const 
ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -307,7 +312,7 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -321,6 +326,8 @@ struct CollectiveMainloopBwdSm90 { struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; + ShapeQKV const shape_V; + ShapeQKV const shape_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum stride_dQaccum; @@ -334,7 +341,8 @@ struct CollectiveMainloopBwdSm90 { float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int* const dq_semaphore; @@ -353,7 +361,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, @@ -367,7 +375,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, @@ -375,6 +383,9 @@ struct CollectiveMainloopBwdSm90 { TileShape_MNK{}, ClusterShape{}); // no mcast for KV if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val @@ -384,14 +395,15 @@ struct CollectiveMainloopBwdSm90 { // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. - return {args.shape_Q, args.shape_K, + return {args.shape_Q, args.shape_K, + args.shape_V, args.shape_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; @@ -406,26 +418,6 @@ struct CollectiveMainloopBwdSm90 { cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } - CUTLASS_DEVICE - cute::tuple get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int n_block, int bidb) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; - int m_block_max = cute::ceil_div(seqlen_q, kBlockM); - if constexpr (Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); - } - } - int m_block_min = 0; - if constexpr (Is_causal || Is_local) { - m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); - } - return {m_block_min, m_block_max}; - } - template CUTLASS_DEVICE void load(Params const& params, @@ -443,7 +435,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, + params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { @@ -468,9 +462,9 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); @@ -609,7 +603,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. 
Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return; } @@ -697,7 +693,9 @@ struct CollectiveMainloopBwdSm90 { bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; - auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); + auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( + seqlen_info, n_block, bidb, params.window_size_left, + params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } @@ -803,8 +801,8 @@ struct CollectiveMainloopBwdSm90 { // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, - params.qhead_per_khead_divmod + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); int m_block = m_block_min; diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index e43904518c..297927bfda 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -12,6 +12,7 @@ #include "cute/tensor.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" @@ -22,7 +23,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm80 { @@ -30,6 +31,7 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kStages = Stages; static_assert(kStages > 0, "kStages must be greater than 0"); using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -43,7 +45,6 @@ struct CollectiveMainloopFwdSm80 { static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool Transpose_V = Is_FP8; - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 80); @@ -53,6 +54,9 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + using MMA_Atom_Arch = std::conditional_t< ArchTag::kMinComputeCapability >= 80, std::conditional_t< @@ -177,12 +181,15 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -195,7 +202,7 @@ struct CollectiveMainloopFwdSm80 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int 
const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -205,6 +212,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -218,6 +226,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -239,7 +248,8 @@ struct CollectiveMainloopFwdSm80 { float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -248,6 +258,7 @@ struct CollectiveMainloopFwdSm80 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; static Params @@ -266,13 +277,16 @@ struct CollectiveMainloopFwdSm80 { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -283,41 +297,11 @@ struct CollectiveMainloopFwdSm80 { args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Split ? 
1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; - } - - CUTLASS_DEVICE - cute::tuple get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_k = seqlen_info.seqlen_k; - int const seqlen_q = seqlen_info.seqlen_q; - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { - int m_idx_max = (m_block + 1) * kBlockM; - if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } - n_block_max = std::min(n_block_max, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN)); - } - int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN); - } - // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; - n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); - } - // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } template @@ -340,7 +324,10 @@ struct CollectiveMainloopFwdSm80 { int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto n_block_min_max = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto n_block_min_max = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); int const n_block_min = get<0>(n_block_min_max); int const n_block_max = get<1>(n_block_min_max); // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier @@ -430,12 +417,15 @@ struct CollectiveMainloopFwdSm80 { } cute::cp_async_fence(); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { @@ -489,11 +479,11 @@ struct CollectiveMainloopFwdSm80 { flash::cp_async_wait(); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = cute::conditional_return( @@ -561,8 +551,8 @@ struct CollectiveMainloopFwdSm80 { if constexpr (!Share_QV_Smem) { preprocess_Q(); } flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, - params.qhead_per_khead_divmod + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); float softcap_val = params.softcap_val; @@ -635,19 +625,17 @@ struct CollectiveMainloopFwdSm80 { --n_block; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? 
n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -660,12 +648,6 @@ struct CollectiveMainloopFwdSm80 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); @@ -674,19 +656,6 @@ struct CollectiveMainloopFwdSm80 { return true; } - CUTLASS_DEVICE - cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); - int const n_block_new_min = idx_k_new_min / kBlockN; - int const n_block_new_max = idx_k_new_max > idx_k_new_min ? 
cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; - // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} - return {n_block_new_min, n_block_new_max}; - } - template CUTLASS_DEVICE bool store_kv_new(Params const& params, @@ -696,7 +665,10 @@ struct CollectiveMainloopFwdSm80 { cute::tuple block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - auto n_block_new_min_max = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); int const n_block_new_min = get<0>(n_block_new_min_max); int const n_block_new_max = get<1>(n_block_new_min_max); if (n_block_new_max <= n_block_new_min) { return false; } @@ -723,20 +695,23 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, + params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position + 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); static_assert(std::is_same_v); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index dbbf2f8f82..536ff855fd 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -16,6 +16,7 @@ #include "named_barrier.hpp" #include "seqlen.h" +#include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" @@ -27,14 +28,16 @@ namespace flash { using namespace cute; -template +template struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using 
TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -43,17 +46,19 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; - static constexpr bool PagedKV = PagedKV_; + static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; static constexpr bool Use_TMA_Q = !PackGQA; - static constexpr bool Use_TMA_KV = !PagedKV; + static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); - using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -64,34 +69,57 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + using SeqlenInfo_t = flash::SeqlenInfoQKNewK; + using BlockMN_t = flash::BlockMN; + + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. - static constexpr bool Mma0_is_RS = false; - // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is enabled"); - static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); - static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); - - using AtomLayoutMNK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. 
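// --- Illustrative aside (not from this patch): how the PV/QV tile shapes relate to TileShape_MNK when kHeadDimV differs. ---
// A plain-int sketch with assumed example block sizes (kBlockM = 64, kBlockN = 128) for the hdim 64 / hdimV 512
// configuration referenced in this file; only the (M, N, K) ordering mirrors TileShape_MNK, TileShape_MNK_PV and
// TileShape_MNK_QV above, everything else is scaffolding for the example.
#include <cstdio>

template <int kBlockM, int kBlockN, int kHeadDim, int kHeadDimV>
struct TileShapesSketch {
    // S = Q K^T tile: (M, N, K) = (kBlockM, kBlockN, kHeadDim)   -> TileShape_MNK
    static constexpr int QK[3] = {kBlockM, kBlockN, kHeadDim};
    // O = P V tile:   (M, N, K) = (kBlockM, kHeadDimV, kBlockN)  -> TileShape_MNK_PV
    static constexpr int PV[3] = {kBlockM, kHeadDimV, kBlockN};
    // Qv tile:        (M, N, K) = (kBlockM, kBlockN, kHeadDimV)  -> TileShape_MNK_QV
    static constexpr int QV[3] = {kBlockM, kBlockN, kHeadDimV};
};

int main() {
    using T = TileShapesSketch</*kBlockM=*/64, /*kBlockN=*/128, /*kHeadDim=*/64, /*kHeadDimV=*/512>;
    std::printf("QK %dx%dx%d  PV %dx%dx%d  QV %dx%dx%d\n",
                T::QK[0], T::QK[1], T::QK[2],
                T::PV[0], T::PV[1], T::PV[2],
                T::QV[0], T::QV[1], T::QV[2]);
    return 0;
}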
+ static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + + // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write + static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; + + using AtomLayoutQK = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma0_is_RS, + !MmaQK_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, - AtomLayoutMNK{})); - using TiledMma1 = decltype(cute::make_tiled_mma( + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma1_is_RS, + !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, - AtomLayoutMNK{})); - - static constexpr int NumMmaThreads = size(TiledMma0{}); + AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); + // For hdim64,512, WG1 can use RS but WG2 must use SS + using TiledMmaPV_RS = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(), + AtomLayoutPV{})); + + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -107,54 +135,67 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + Int, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), 
decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); + static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{})), Int>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + using SmemCopyAtomP = Copy_Atom; // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. // For FP16/BF16 we don't do any transposing. - static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0)); - static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0; - // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). - static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0)); - using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; - using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; using LDSM_value_shape = Shape<_2, _2, _1, _4>; using LDSM_value_stride = Stride<_1, _2, _16, _4>; - using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; using S2RTiledCopyVt = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); - using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; - using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; using STSM_value_shape = Shape<_1, _4, _2, _2>; using STSM_value_stride = Stride<_0, _1, _4, _8>; using STSM_divide_shape = Shape<_8, _16>; - // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. 
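A small sketch of the FP8 V-transpose tile selection described in the comments above, assuming only the stated divisibility rules (kHeadDimV and kBlockN multiples of 32, and at least one of them a multiple of 64). The helper name is illustrative, not part of the kernel.

#include <cstdio>

struct TransposeTile { int rows, cols; };  // rows along kHeadDimV, cols along kBlockN

constexpr TransposeTile select_transpose_tile(int kHeadDimV, int kBlockN) {
    // If kHeadDimV is a multiple of 64 we transpose in 64 x 32 blocks; otherwise kBlockN
    // must be a multiple of 64 and we transpose in 32 x 64 blocks.
    return (kHeadDimV % 64 == 0) ? TransposeTile{64, 32}
                                 : (kBlockN % 64 == 0 ? TransposeTile{32, 64} : TransposeTile{0, 0});
}

int main() {
    constexpr int kBlockN = 128;
    const int hdims_v[] = {64, 96, 512};
    for (int hdim_v : hdims_v) {
        TransposeTile t = select_transpose_tile(hdim_v, kBlockN);
        if (t.rows == 0) std::printf("hdim_v=%3d: unsupported with kBlockN=%d\n", hdim_v, kBlockN);
        else             std::printf("hdim_v=%3d: transpose in %d x %d blocks\n", hdim_v, t.rows, t.cols);
    }
    return 0;
}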
// using STSM_value_shape = Shape<_2, _4, _1, _2>; @@ -167,19 +208,20 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); - // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where + // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. 
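The gmem copy geometry above falls out of a few small formulas; this host-side sketch reproduces them for one example pair of head dimensions (element size 2 corresponds to fp16/bf16, and the chosen dims are illustrative only).

#include <cstdio>
#include <numeric>   // std::gcd

int main() {
    constexpr int kHeadDim    = 64;   // QK head dim (example)
    constexpr int kHeadDimV   = 512;  // V head dim (example)
    constexpr int elem_bytes  = 2;    // fp16 / bf16

    constexpr int kHeadDimGCD       = std::gcd(kHeadDim, kHeadDimV);
    constexpr int kGmemElemsPerLoad = 16 / elem_bytes;  // one 128-bit vectorized load
    static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0,
                  "both head dims must be multiples of the vector width");

    // Each "row" of the copy covers at most one 128-byte cache line, and each thread gets
    // at least 2 vectorized loads in the K direction (needed for non-interleaved rotary).
    constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * elem_bytes;
    constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128
                               : kBytePerHalfRow % 64 == 0 ? 64 : 32) / elem_bytes;
    constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;

    std::printf("gcd(hdim)=%d, elems/load=%d, kBlockKGmem=%d, threads/row=%d\n",
                kHeadDimGCD, kGmemElemsPerLoad, kBlockKGmem, kGmemThreadsPerRow);
    return 0;
}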
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemLayoutAtom = Layout, Int>, @@ -221,14 +263,22 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -241,65 +291,88 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". 
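The TMA transaction byte counts above are just tile elements times element size, per pipeline stage; the sketch below shows that accounting for one illustrative configuration and why the old static_assert equating the K and V transaction sizes had to go once kHeadDim and kHeadDimV may differ.

#include <cstdio>
#include <cstdint>

int main() {
    constexpr int kBlockM = 64, kBlockN = 128;
    constexpr int kHeadDim = 64, kHeadDimV = 512;  // K and V head dims may now differ
    constexpr int elem_bytes = 2;                  // fp16 / bf16

    constexpr uint32_t bytesQ  = uint32_t(kBlockM) * kHeadDim  * elem_bytes;
    constexpr uint32_t bytesK  = uint32_t(kBlockN) * kHeadDim  * elem_bytes;  // per stage
    constexpr uint32_t bytesV  = uint32_t(kBlockN) * kHeadDimV * elem_bytes;  // per stage
    constexpr uint32_t bytesQv = uint32_t(kBlockM) * kHeadDimV * elem_bytes;  // only if HasQv

    // With kHeadDim != kHeadDimV, bytesK == bytesV no longer holds.
    std::printf("Q=%u K=%u V=%u Qv=%u bytes per TMA transaction\n", bytesQ, bytesK, bytesV, bytesQv);
    return 0;
}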
- struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; }; - struct TensorStorageWithPNoTranspose : cute::aligned_struct { + struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; - using TensorStorageNoTranspose = std::conditional_t; + using TensorStorageNoTranspose = std::conditional_t< + MmaPV_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); - struct TensorStorageTransposeV : cute::aligned_struct { + struct TensorStorageTransposeV : cute::aligned_struct { cute::array_aligned, SmemAlignmentV> smem_v; cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemScale_t smem_scale; }; using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = IntraWGOverlap + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? 
kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2; - static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; + static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; - Element* const ptr_K; // Not Element const* since we might append to KV cache in-place + Element* const ptr_K; // not Element const* since we might append to KV cache in-place ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -312,7 +385,7 @@ struct CollectiveMainloopFwdSm90 { float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; - int const window_size_left = -1, window_size_right = -1, sink_token_length = 0; + int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; @@ -322,6 +395,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const* const seqlens_rotary = nullptr; }; // Device side kernel params @@ -335,12 +409,17 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -351,17 +430,20 @@ struct CollectiveMainloopFwdSm90 { ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; cutlass::FastDivmod page_size_divmod; + cutlass::FastDivmod blockN_per_page_size_divmod; cutlass::FastDivmod qhead_per_khead_divmod; TMA_Q tma_load_Q; TMA_K tma_load_K; TMA_V tma_load_V; TMA_K tma_load_K_new; TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; float const softmax_scale_log2; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; - int const window_size_left, window_size_right, sink_token_length; + int const window_size_left, window_size_right; + cutlass::FastDivmod attention_chunk_divmod; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; @@ -370,6 +452,7 @@ struct CollectiveMainloopFwdSm90 { int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; + int const *const seqlens_rotary = nullptr; }; static Params @@ -388,12 +471,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this 
N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); TMA_K tma_load_K_new = make_tma_copy_B_sm90( @@ -402,13 +487,29 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), select<1, 0, 2, 3>(args.shape_K_new), select<1, 0, 2, 3>(args.stride_V_new)); + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); TMA_V tma_load_V_new = make_tma_copy( GmemTiledCopyKV{}, cute::conditional_return(mVnew, mV), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); auto const shape_Q_packed = cute::conditional_return( @@ -419,33 +520,51 @@ struct CollectiveMainloopFwdSm90 { args.stride_Q, make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); if (get<1>(args.shape_rotary) > 0) { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); + int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K); + if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { + assert(page_size % kBlockN == 0); + assert(!args.leftpad_k); + } + // Avoid dividing by zero + cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); + attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. 
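The PackGQA reshape built above views Q of logical shape (seqlen_q, head_dim, num_heads_q, batch) as ((qhead_per_khead, seqlen_q), head_dim, num_heads_k, batch), so all query heads that share one KV head are handled by the same block. The sketch below checks that the packed strides address the same elements as the original view; the example strides are illustrative, not the kernel's.

#include <cstdio>

int main() {
    // Example: 8 query heads, 2 KV heads -> qhead_per_khead = 4.
    constexpr int seqlen_q = 3, head_dim = 2, num_heads_q = 8, num_heads_k = 2;
    constexpr int qhead_per_khead = (num_heads_q + num_heads_k - 1) / num_heads_k;

    // Example strides (in elements) for the (seq, d, head, batch) view.
    constexpr long stride_seq = head_dim * num_heads_q, stride_d = 1,
                   stride_head = head_dim, stride_batch = seqlen_q * head_dim * num_heads_q;

    // Packed strides: ((stride_head, stride_seq), stride_d, stride_head * qhead_per_khead, stride_batch)
    auto packed_offset = [&](int hq_in_group, int s, int d, int hk, int b) {
        return hq_in_group * stride_head + s * stride_seq + d * stride_d
             + hk * (stride_head * qhead_per_khead) + b * stride_batch;
    };
    auto orig_offset = [&](int s, int d, int h, int b) {
        return s * stride_seq + d * stride_d + h * stride_head + b * stride_batch;
    };

    // Query head 6 belongs to KV head 1 and is position 2 within its group.
    std::printf("packed=%ld original=%ld\n", packed_offset(2, 1, 0, 1, 0), orig_offset(1, 0, 6, 0));
    return 0;
}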
// To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, - cutlass::FastDivmod(int(get<0>(args.shape_K))), + cutlass::FastDivmod(page_size), // page_size_divmod + cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)), // blockN_per_page_size_divmod cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, - args.window_size_left, args.window_size_right, args.sink_token_length, + args.window_size_left, args.window_size_right, attention_chunk_divmod, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, - args.seqused_q, args.seqused_k, args.leftpad_k}; + args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -453,6 +572,9 @@ struct CollectiveMainloopFwdSm90 { static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA_Q) { cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } } if constexpr (Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); @@ -464,37 +586,6 @@ struct CollectiveMainloopFwdSm90 { } } - CUTLASS_DEVICE - cute::tuple get_n_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int const seqlen_k = seqlen_info.seqlen_k; - int const seqlen_q = seqlen_info.seqlen_q; - int n_block_max = cute::ceil_div(seqlen_k, kBlockN); - if constexpr (Is_causal || Is_local) { - int m_idx_max = (m_block + 1) * kBlockM; - // TODO: check off-by-1 error - if (PackGQA) { m_idx_max = params.qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } - n_block_max = std::min(n_block_max, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + params.window_size_right, kBlockN)); - } - int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = params.qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - params.window_size_left) / kBlockN); - } - // if (threadIdx.x 
== 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - if constexpr (Split) { - int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits); - n_block_min = n_block_min + split_idx * num_n_blocks_per_split; - n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max); - } - // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; - } - template CUTLASS_DEVICE void load(Params const& params, @@ -509,8 +600,15 @@ struct CollectiveMainloopFwdSm90 { int &work_idx ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { @@ -541,6 +639,7 @@ struct CollectiveMainloopFwdSm90 { return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); int const thread_idx = threadIdx.x % NumProducerThreads; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; @@ -554,31 +653,50 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); + Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } - Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); } // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); - Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k) + Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK)); // (TMA, PIPE) auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); - Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) + Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); + + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? 
bidb_kv : 0; - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx ); // Set up for transposing V, only used if Transpose_V @@ -630,9 +748,10 @@ struct CollectiveMainloopFwdSm90 { auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { pipeline_k.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA(); copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tKgK_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index())); @@ -643,9 +762,10 @@ struct CollectiveMainloopFwdSm90 { auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) { auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt); pipeline_v_load.producer_acquire(smem_pipe_write); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { + auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA(); copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST), - tVgVt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index())); } else { constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index())); @@ -666,7 +786,7 @@ struct CollectiveMainloopFwdSm90 { pipeline_v.producer_commit(smem_pipe_write); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. Without this we get race conditions. 
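A sketch of the index mapping that get_indices_for_K_TMA / get_indices_for_V_TMA appear to perform (their implementation lives in PagedKVManager, not in this file): because the TMA paged path requires page_size % kBlockN == 0 and no leftpad, a K block maps to one page-table entry plus a block offset inside that page, and the page index is then fed into the TMA's trailing (batch-like) mode. The helper and field names below are illustrative, not the manager's actual API.

#include <cstdio>
#include <vector>

struct TmaKvIndex { int n_block_idx; int page_idx; };

TmaKvIndex paged_indices(int n_block, int page_size, int kBlockN,
                         const std::vector<int>& page_table /* pages of this sequence */) {
    int blocks_per_page = page_size / kBlockN;         // exact, since page_size % kBlockN == 0
    int page_slot       = n_block / blocks_per_page;   // which page-table entry
    int block_in_page   = n_block % blocks_per_page;   // which kBlockN tile inside that page
    return {block_in_page, page_table[page_slot]};
}

int main() {
    std::vector<int> page_table = {7, 3, 11};          // physical page ids of one sequence
    auto idx = paged_indices(/*n_block=*/5, /*page_size=*/256, /*kBlockN=*/128, page_table);
    std::printf("n_block 5 -> page %d, block-in-page %d\n", idx.page_idx, idx.n_block_idx);
    return 0;
}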
- cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::ProducerWG) /*id*/); + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/); pipeline_vt.consumer_release(smem_pipe_read); }; @@ -678,8 +798,10 @@ struct CollectiveMainloopFwdSm90 { bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()); if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.template load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());} @@ -690,16 +812,21 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -707,6 +834,15 @@ struct CollectiveMainloopFwdSm90 { auto &barrier_Q = shared_storage.pipelines.barrier_Q; cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? 
bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } } // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem @@ -726,8 +862,10 @@ struct CollectiveMainloopFwdSm90 { PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind ++smem_pipe_write; if (should_load_KV) { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); + } else { + paged_kv_manager.load_page_table_TMA(n_block); } if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); @@ -742,33 +880,6 @@ struct CollectiveMainloopFwdSm90 { n_block_prev = n_block; if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } } - // if constexpr (Is_local) { - // Disable sink token code for now - if constexpr (false && Is_local) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - #pragma unroll 1 - for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind - ++smem_pipe_write; - if (should_load_KV) { - if constexpr (PagedKV) { - paged_kv_manager.template load_page_table(n_block); - } - if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); } - load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - if constexpr (!Transpose_V) { - if constexpr (IntraWGOverlap) { - load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/); - } else { - load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); - } - } - } - n_block_prev = n_block; - if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); } - } - } scheduler_prefetch(); if constexpr (!Transpose_V && IntraWGOverlap) { if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); } @@ -821,13 +932,19 @@ struct CollectiveMainloopFwdSm90 { CUTLASS_DEVICE void mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (UseSchedulerBarrier) { // We have NamedBarrier for up to 3 WGs static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); // WG1 needs the very first signal to start - if (flash::canonical_warp_group_idx_nosync() == 1) { + if (warp_group_idx == 1) { cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); } } @@ -857,7 +974,10 @@ struct CollectiveMainloopFwdSm90 { int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -867,57 +987,81 @@ struct CollectiveMainloopFwdSm90 { Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = [&] { - if constexpr (Mma1_is_RS) { - // We might not have smem_p if !Mma1_is_RS1, just use smem_q as a placeholder since we don't use it + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); } else { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); } }(); - - if constexpr (!Mma0_is_RS) { - static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and - stride<0>(typename TiledMma0::BLayout{}) == 0 and - size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{}); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); + + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr 
int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma0 tiled_mma0; - TiledMma1 tiled_mma1; - auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); - - auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); // Allocate "fragments/descriptors" - Tensor tSrQ = wg_mma0.partition_fragment_A(sQ); - Tensor tSrK = wg_mma0.partition_fragment_B(sK); - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter - clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; - flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, - params.qhead_per_khead_divmod + flash::Mask mask( + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + params.attention_chunk_divmod, params.qhead_per_khead_divmod ); float softcap_val = params.softcap_val; @@ -926,22 +1070,37 @@ struct CollectiveMainloopFwdSm90 { float const k_descale = params.ptr_k_descale == nullptr ? 
1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; softcap_val *= q_descale * k_descale; } - // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn - // -inf to e.g. -50.0, which can affect the attention softmax. + // Softcapping needs to happen before masking since if we apply after masking, softcapping + // can turn -inf to e.g. -50.0, which can affect the attention softmax. auto scoremod_premask_fn = [&](auto& tSrS) { if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } }; + auto write_P_to_smem = [&](auto& tOrP) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); + }; + + auto arrive_on_P_write_barrier = [&] { + cutlass::arch::fence_view_async_shared(); + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } + }; + auto &barrier_Q = shared_storage.pipelines.barrier_Q; if constexpr (!AppendKV) { barrier_Q.wait(work_idx % 2); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_q, + seqlen_info.seqlen_rotary); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); int const qhead_per_khead = !PackGQA ? 
1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { @@ -961,93 +1120,106 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } } - if constexpr (Mma0_is_RS) { + if constexpr (MmaQK_is_RS) { using SmemCopyAtomQ = Copy_Atom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma0); + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } - // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask.template apply(tSrS, m_block, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { - cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); - cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 - } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } --n_block; + // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block. 
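As a reference for the max_get_scale / online_softmax / rescale_o calls above, here is a minimal scalar sketch of the online-softmax bookkeeping: per row we keep a running max and a running sum, and whenever the max grows we rescale both the sum and the O accumulator by exp(old_max - new_max), which is the scores_scale handed to softmax.rescale_o. This is a simplification (no log2/exp2 trick, no softcap, one scalar O element), not the kernel's code.

#include <cstdio>
#include <cmath>
#include <vector>
#include <algorithm>

int main() {
    // Scores for one query row, arriving one K-block at a time.
    std::vector<std::vector<float>> blocks = {{1.0f, 2.0f}, {5.0f, 0.5f}, {3.0f, 4.0f}};
    float row_max = -INFINITY, row_sum = 0.0f, acc_o = 0.0f;  // acc_o stands in for one O element

    for (const auto& blk : blocks) {
        float blk_max = *std::max_element(blk.begin(), blk.end());
        float new_max = std::max(row_max, blk_max);
        float scale   = row_max == -INFINITY ? 1.0f : std::exp(row_max - new_max);
        row_sum *= scale;
        acc_o   *= scale;                      // softmax.rescale_o(tOrO, scores_scale)
        for (float s : blk) {
            float p = std::exp(s - new_max);   // online_softmax: exponentiate against the new max
            row_sum += p;
            acc_o   += p * 1.0f;               // P @ V with a dummy V value of 1
        }
        row_max = new_max;
    }
    // finalize: divide by the row sum (acc_o / row_sum is exactly 1 here since V == 1).
    std::printf("row_max=%.3f row_sum=%.3f O=%.3f\n", row_max, row_sum, acc_o / row_sum);
    return 0;
}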
auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) { static constexpr bool Check_inf = decltype(check_inf_type)::value; PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!Mma1_is_RS) { - cutlass::arch::fence_view_async_shared(); - __syncwarp(); - } + if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); } }; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? 
(m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -1060,20 +1232,19 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/); - // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang softmax.rescale_o(tOrO, scores_scale); @@ -1087,27 +1258,50 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + auto smem_pipe_read_prev = smem_pipe_read; + if constexpr (!Is_first_iter) { ++smem_pipe_read; } + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); - warp_scheduler_barrier_arrive(); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read); // release K + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + if constexpr (!HasQv) { + warp_scheduler_barrier_arrive(); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + } else { + if constexpr (Is_first_iter) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + } + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read); // release K + warpgroup_wait<0>(); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); } softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); + if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!MmaPV_use_RS_WG1) { + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } else { + TiledMmaPV_RS tiled_mma_pv_rs; + flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + } + if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); } + warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // 
release V - ++smem_pipe_read; }; auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; @@ -1115,19 +1309,17 @@ struct CollectiveMainloopFwdSm90 { --n_block; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; - int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); - int const n_block_min_causal_local_mask = - std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); + int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_right, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); } } - int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; - int const n_block_min_before_local_mask = !Is_local - ? n_block_min - : std::max(n_block_min, - cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); + int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( + seqlen_info, m_block, n_block_min, params.window_size_left, + params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { @@ -1140,36 +1332,114 @@ struct CollectiveMainloopFwdSm90 { for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } - // Disable sink token code for now - // int n_block_sink_max = cute::ceil_div(params.sink_token_length, kBlockN); - // #pragma unroll 1 - // for (n_block = std::min(n_block, n_block_sink_max - 1); n_block >= 0; --n_block) { - // fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); - // } } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; } ++work_idx; return true; } - CUTLASS_DEVICE - cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, - int m_block, int bidb, int split_idx=0, int num_splits=1) { - static constexpr int kBlockN = get<1>(TileShape_MNK{}); - auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, num_splits); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); - int const n_block_new_min = idx_k_new_min / kBlockN; - int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; - // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} - return {n_block_new_min, n_block_new_max}; + template + CUTLASS_DEVICE bool + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + // clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + int n_block = n_block_max - 1; + // If HasQv, then by the time P is ready, V must have been ready as well + if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + --n_block; + + #pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + if constexpr (!HasQv) { + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + } + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; } template @@ -1185,7 +1455,11 @@ struct CollectiveMainloopFwdSm90 { ) { auto [m_block, bidh, 
bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); + if (n_block_new_max <= n_block_new_min) { return false; } Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); @@ -1207,10 +1481,11 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) @@ -1283,7 +1558,10 @@ struct CollectiveMainloopFwdSm90 { cute::tuple block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = get_n_block_k_new_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); if (n_block_new_max <= n_block_new_min) { return false; } // as_position_independent_swizzle_tensor makes address calculation easier @@ -1302,27 +1580,32 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); - int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, - params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); + params.is_rotary_interleaved, thread_idx, seqlen_k_new, + seqlen_info.seqlen_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + // This is used to index into the batch dimension of mK and mV + int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; + + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, - params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k + params.ptr_V, params.headdim_v, params.stride_V, + params.page_size_divmod, params.blockN_per_page_size_divmod, + bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1335,7 +1618,7 @@ struct CollectiveMainloopFwdSm90 { } static_assert(std::is_same_v); - static_assert(!PagedKV || std::is_same_v); + static_assert(!PagedKVNonTMA || std::is_same_v); GmemTiledCopyAppendKV gmem_tiled_copy_kv; auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx); Tensor tKsK = gmem_thr_copy_kv.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) @@ -1347,13 +1630,19 @@ struct CollectiveMainloopFwdSm90 { Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); #pragma unroll for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); if (get<1>(params.shape_rotary) <= 0) { pipeline_k_new.consumer_wait(smem_pipe_read); Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tKgK_cur = tKgK(_, _, _, n_block); // Clear_OOB_K must be false 
since we don't want to write zeros to gmem flash::copy( @@ -1364,18 +1653,18 @@ struct CollectiveMainloopFwdSm90 { } } else { Tensor gK_cur = gK(_, _, n_block); - auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); + auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); + rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block); } else { auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block); pipeline_k_new.consumer_wait(smem_pipe_read); - rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); + rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } - // Without this sync I'm getting race condition when seqlen_k is large + // Without this fence I'm getting race condition when seqlen_k is large cutlass::arch::fence_view_async_shared(); // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized // before calling. @@ -1388,11 +1677,11 @@ struct CollectiveMainloopFwdSm90 { pipeline_v_new.consumer_wait(smem_pipe_read); int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index()); - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tKcK, tKpK, n_limit); + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); } else { paged_kv_manager.store_V(n_block, tVsV_cur); } @@ -1403,7 +1692,7 @@ struct CollectiveMainloopFwdSm90 { #pragma unroll 1 for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) { - if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } + if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); } store_K(n_block, smem_pipe_read); // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } store_V(n_block, smem_pipe_read); diff --git a/hopper/mask.h b/hopper/mask.h index 02d046268c..d43e5ee156 100644 --- a/hopper/mask.h +++ b/hopper/mask.h @@ -22,11 +22,13 @@ struct Mask { int const thread_idx; int const seqlen_q, seqlen_k; int const window_size_left, window_size_right, sink_token_length; + cutlass::FastDivmod const attention_chunk_divmod; cutlass::FastDivmod const qhead_per_khead_divmod; CUTLASS_DEVICE Mask(const int thread_idx, const int seqlen_q, const int seqlen_k, const int window_size_left, const int window_size_right, const int sink_token_length, + cutlass::FastDivmod const &attention_chunk_divmod, cutlass::FastDivmod const &qhead_per_khead_divmod) : thread_idx(thread_idx) , seqlen_q(seqlen_q) @@ -34,6 +36,7 @@ struct Mask { , window_size_left(window_size_left) , window_size_right(window_size_right) , sink_token_length(sink_token_length) + , attention_chunk_divmod(attention_chunk_divmod) , qhead_per_khead_divmod(qhead_per_khead_divmod) { }; @@ -100,16 +103,21 @@ struct Mask { 
} else { int const local_row_offset_right = causal_row_offset + window_size_right; int const local_row_offset_left = causal_row_offset - 1 - window_size_left; - int const col_limit_sink = sink_token_length - n_block * kBlockN; + int const col_limit_sink = sink_token_length - n_block * kBlockN; // TODO: subtract thread_col_offset? #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { int const row_idx = !PackGQA ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); - int const col_limit_right = !Seqlenk_mask + int col_limit_right = !Seqlenk_mask ? row_idx + local_row_offset_right : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); - int const col_limit_left = row_idx + local_row_offset_left; + int col_limit_left = row_idx + local_row_offset_left; + if (attention_chunk_divmod.divisor > 0) { + int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset; + col_limit_left = std::max(col_limit_left, col_limit_left_chunk); + col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor); + } #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { int const col_idx = int(get(t0ScS_rowcol(m, n))); @@ -118,6 +126,7 @@ struct Mask { } } } else { + // TODO: backward does not support attention_chunk yet int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; if constexpr (Causal_mask) { diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index f77ea77829..a7dfb6439a 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -49,28 +49,24 @@ static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNa enum class FwdNamedBarriers { QueryEmpty = 0, - ProducerWG = 1, - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - WarpSchedulerWG1 = 4, - WarpSchedulerWG2 = 5, - WarpSchedulerWG3 = 6, - AppendKV = 7, - QueryRotated = 8, + WarpSchedulerWG1 = 1, + WarpSchedulerWG2 = 2, + WarpSchedulerWG3 = 3, + AppendKV = 4, + QueryRotated = 5, + PFull = 6, + PEmpty = 7, }; enum class BwdNamedBarriers { KVEmpty = 0, PdS = 1, - // This needs to match FwdNamedBarriers::TileCountSmemEmpty since TileScheduler uses it - TileCountSmemEmpty = 2, - TileCountSmemFull = 3, - dQEmptyWG1 = 4, - dQEmptyWG2 = 5, - dQEmptyWG3 = 6, - dQFullWG1 = 7, - dQFullWG2 = 8, - dQFullWG3 = 9, + dQEmptyWG1 = 2, + dQEmptyWG2 = 3, + dQEmptyWG3 = 4, + dQFullWG1 = 5, + dQFullWG2 = 6, + dQFullWG3 = 7, }; } // flash diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 0f710e5493..9ea59bcc2a 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -14,7 +14,7 @@ namespace flash { using namespace cute; -template +template struct PagedKVManager { // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), // load_page_table(2), load_K(2), load_V(1), etc. @@ -23,14 +23,17 @@ struct PagedKVManager { // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for // rotary where we want each thread to have at least 2 loads per row. 
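The KV_Same_Iter schedule described in the comment above is easiest to see as a plain loop: with KV_Same_Iter=false the V load lags one n_block behind the K load, so the V pointers computed in the previous iteration are consumed while the next page-table entry and K tile are already in flight. A minimal host-side sketch, using hypothetical load_page_table/load_K/load_V stand-ins rather than the real PagedKVManager member functions:

    // Illustration only: the iteration orders implied by KV_Same_Iter.
    #include <cstdio>

    static void load_page_table(int n) { std::printf("page_table(%d) ", n); }
    static void load_K(int n)          { std::printf("K(%d) ", n); }
    static void load_V(int n)          { std::printf("V(%d) ", n); }

    // KV_Same_Iter = true: K and V of the same n_block are handled in the same iteration.
    static void schedule_same_iter(int n_blocks) {
        for (int n = 0; n < n_blocks; ++n) { load_page_table(n); load_K(n); load_V(n); }
    }

    // KV_Same_Iter = false: V lags one iteration behind K.
    static void schedule_lagged(int n_blocks) {
        for (int n = 0; n < n_blocks; ++n) {
            load_page_table(n);
            load_K(n);
            if (n > 0) { load_V(n - 1); }
        }
        if (n_blocks > 0) { load_V(n_blocks - 1); }  // drain the last V
    }

    int main() {
        schedule_same_iter(3); std::printf("\n");  // page_table(0) K(0) V(0) page_table(1) K(1) V(1) ...
        schedule_lagged(3);    std::printf("\n");  // page_table(0) K(0) page_table(1) K(1) V(0) page_table(2) K(2) V(1) V(2)
    }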
+ static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. - static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB"); - static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element); + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); @@ -59,6 +62,8 @@ struct PagedKVManager { using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, // since those require int64_t arithmetic. We optimize by having threads split this work. @@ -66,47 +71,63 @@ struct PagedKVManager { // that each thread needs to load for the case of hdim 128 and kBlockN = 176. // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. 
+ static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); using TensorPageOffset = decltype(make_tensor>(Shape>{})); using TensorKVPtr = decltype(make_tensor(Shape>{})); GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; cutlass::FastDivmod const &page_size_divmod; + cutlass::FastDivmod const &blockN_per_page_size_divmod; int const thread_idx; int const seqlen_k; int const leftpad_k; + int const* const ptr_page_table; GmemThrCopyKVCpAsync const gmem_thr_copy_kv; TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; TensortKpK tKpK; + TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; - + int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA CUTLASS_DEVICE - PagedKVManager(int const* const ptr_page_table, + PagedKVManager(int const* const ptr_page_table_, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, - Element* const ptr_V, StrideKV const &stride_V, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, - int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k + cutlass::FastDivmod const &blockN_per_page_size_divmod, + int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, + int bidb_kv_idx ) : page_size_divmod(page_size_divmod) + , blockN_per_page_size_divmod(blockN_per_page_size_divmod) , thread_idx(thread_idx) , seqlen_k(seqlen_k) , leftpad_k(leftpad_k) + , ptr_page_table(ptr_page_table_) , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) + , bidb_kv_idx(bidb_kv_idx) + , bidb_kv_idx_prev(bidb_kv_idx) { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); - mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); #pragma unroll for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } + tVpV = cute::conditional_return(tKpK, tVpV_); }; template @@ -131,6 +152,38 @@ struct PagedKVManager { if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } }; + template + CUTLASS_DEVICE + void load_page_table_TMA(const int n_block) { + // We require that page size is a multiple of kBlockN, and there's no leftpad_k + if (ptr_page_table) { + bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; + } else { + n_block_idx = n_block; + } + if constexpr (First_iter && 
!KV_Same_Iter) { + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + } + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_K_TMA() { + return {n_block_idx, bidb_kv_idx}; + }; + + CUTLASS_DEVICE + cute::tuple get_indices_for_V_TMA() { + if constexpr (KV_Same_Iter) { + return {n_block_idx, bidb_kv_idx}; + } else { + cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; + bidb_kv_idx_prev = bidb_kv_idx; + n_block_idx_prev = n_block_idx; + return indices; + } + }; + CUTLASS_DEVICE TensorKVPtr compute_K_ptr() { Tensor tPrKPtr = make_tensor(Shape>{}); @@ -200,27 +253,27 @@ struct PagedKVManager { // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // Faster to rely on the cp.async to clear smem that are out of bound, // rather than calling cute::clear directly. // We have to be careful not to write to smem past `kBlockN` if !EvenN. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) { - bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); } } } @@ -269,24 +322,24 @@ struct PagedKVManager { if constexpr (KV_Same_Iter) { compute_V_ptr(); } // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = 
gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVrV); ++m) { - bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); if (should_load) { #pragma unroll for (int k = 0; k < size<2>(tVrV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tKpK(_0{}, k)) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); } } diff --git a/hopper/rotary.h b/hopper/rotary.h index 5e30456c2d..aa3602cc79 100644 --- a/hopper/rotary.h +++ b/hopper/rotary.h @@ -226,7 +226,7 @@ struct Rotary { // The main bottleneck here is actually instruction cache misses. - // Similar to PagedKV, it's expensive to compute the pointers. + // Similar to PagedKVNonTMA, it's expensive to compute the pointers. // We split the work among threads loading the same row, then __shfl_sync the pointers. static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow); Tensor tPrCosPtr = make_tensor(Shape>{}); @@ -350,7 +350,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_interleaved(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -377,7 +377,7 @@ struct Rotary { CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin)); static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2); static_assert(decltype(size<0>(tRrCos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -385,7 +385,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKsK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); auto mK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); @@ -400,7 +400,7 @@ struct Rotary { Tensor rK = make_fragment_like(tKsK(_, m, k)); cute::copy(tiled_copy_k, tKsK(_, m, k), rK); if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); } - if constexpr (!PagedKV) { + if constexpr (!PagedKVNonTMA) { cute::copy(tiled_copy_k, rK, tKgK(_, m, k)); } else { int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; @@ -412,7 +412,7 @@ struct Rotary { } }; - template + template CUTLASS_DEVICE void apply_K_contiguous(TensorsK const &sK, // (kBlockN, kHeadDim) @@ -439,7 +439,7 @@ struct Rotary { 
CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont)); CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont)); static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow)); } @@ -449,7 +449,7 @@ struct Rotary { for (int m = 0; m < size<1>(tKcK); ++m) { int const row = get<0>(tKcK(_0{}, m, _0{})); Tensor gK_cur_copy = [&] { - if constexpr (PagedKV) { + if constexpr (PagedKVNonTMA) { Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow)); Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{}); return cute::tiled_divide(mK_cur, Shape>{}); diff --git a/hopper/seqlen.h b/hopper/seqlen.h index 21a7471280..5547238b34 100644 --- a/hopper/seqlen.h +++ b/hopper/seqlen.h @@ -64,12 +64,13 @@ struct SeqlenInfoQKNewK { int const leftpad_k; int const offset_q, offset_k, offset_k_new; - int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k; + int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary; CUTLASS_DEVICE SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, - int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k + int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k, + int const* const seqlens_rotary ) : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0) , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb]) @@ -85,6 +86,7 @@ struct SeqlenInfoQKNewK { ? 0 : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0)) , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new) + , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb]) { } diff --git a/hopper/setup.py b/hopper/setup.py index d95be9ad40..850fb0b520 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -33,7 +33,7 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -PACKAGE_NAME = "flash_attn" +PACKAGE_NAME = "flash_attn_3" BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" @@ -64,6 +64,8 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -90,7 +92,9 @@ def _write_ninja_file(path, objects, ldflags, library_target, - with_cuda) -> None: + with_cuda, + **kwargs, # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension + ) -> None: r"""Write a ninja file that does the desired compiling and linking. 
`path`: Where to write this file @@ -150,6 +154,8 @@ def sanitize_flags(flags): flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}') cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80'] flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}') + cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}') flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') flags.append(f'ldflags = {" ".join(ldflags)}') @@ -182,10 +188,13 @@ def sanitize_flags(flags): # to make this work on Windows too. nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' ] cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + ] + cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100' ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') @@ -199,6 +208,8 @@ def sanitize_flags(flags): rule = 'cuda_compile' elif source_file.endswith('_sm80.cu'): rule = 'cuda_compile_sm80' + elif source_file.endswith('_sm100.cu'): + rule = 'cuda_compile_sm100' else: rule = 'cuda_compile_sm80_sm90' else: @@ -244,6 +255,7 @@ def sanitize_flags(flags): blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined] blocks += [devlink_rule, link_rule, build, devlink, link, default] content = "\n\n".join("\n".join(b) for b in blocks) # Ninja requires a new lines at the end of the .ninja file @@ -333,22 +345,19 @@ def open_url(url): return urllib.request.urlopen(request, timeout=300) -def download_and_copy(name, src_path, dst_path, version, url_func): +def download_and_copy(name, src_func, dst_path, version, url_func): if is_offline_build(): return flashattn_cache_path = get_flashattn_cache_path() base_dir = os.path.dirname(__file__) system = platform.system() - try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] - except KeyError: - arch = platform.machine() + arch = platform.machine() + arch = {"arm64": "aarch64"}.get(arch, arch) supported = {"Linux": "linux", "Darwin": "linux"} url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path - platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" - src_path = src_path(platform_name, version) if callable(src_path) else 
src_path src_path = os.path.join(tmp_path, src_path) download = not os.path.exists(src_path) if download: @@ -364,11 +373,13 @@ def download_and_copy(name, src_path, dst_path, version, url_func): def nvcc_threads_args(): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" + nvcc_threads = os.getenv("NVCC_THREADS") or "2" return ["--threads", nvcc_threads] -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"} + exe_extension = sysconfig.get_config_var("EXE") @@ -384,28 +395,44 @@ def nvcc_threads_args(): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - check_if_cuda_home_none("flash_attn") + check_if_cuda_home_none(PACKAGE_NAME) _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.3"): # nvcc 12.3 gives the best perf currently + # ptxas 12.8 gives the best perf currently + # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 + # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. + if bare_metal_version != Version("12.8"): download_and_copy( - name="nvcc", src_path=f"bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) download_and_copy( - name="nvcc", src_path=f"nvvm/bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", + dst_path="nvvm/bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) base_dir = os.path.dirname(__file__) - ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin") + ctk_path_new = os.path.abspath(os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin")) nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc + # nvcc 12.8 seems to hard-code 
looking for cicc in ../nvvm/bin/cicc os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions @@ -443,10 +470,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -456,7 +486,18 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all"] + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -470,6 +511,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" @@ -484,6 +533,7 @@ def nvcc_threads_args(): ) if not DISABLE_SPLIT: sources += ["flash_fwd_combine.cu"] + sources += ["flash_prepare_scheduler.cu"] nvcc_flags = [ "-O3", "-std=c++17", @@ -493,9 +543,9 @@ def nvcc_threads_args(): # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage", # printing out number of registers "--resource-usage", # printing out number of registers # f"--split-compile={os.getenv('NVCC_THREADS', '4')}", # split-compile is faster - "-lineinfo", + "-lineinfo", # TODO: disable this for release to reduce binary size 
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", # Necessary for the WGMMA shapes that we use - # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL + "-DCUTLASS_ENABLE_GDC_FOR_SM90", # For PDL "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging "-DNDEBUG", # Important, otherwise performance is severely impacted ] @@ -513,13 +563,14 @@ def nvcc_threads_args(): ext_modules.append( CUDAExtension( - name="flash_attn_3_cuda", + name=f"{PACKAGE_NAME}._C", sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + feature_args, + "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + feature_args, "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, }, include_dirs=include_dirs, + py_limited_api=True, ) ) @@ -628,4 +679,5 @@ def run(self): "packaging", "ninja", ], + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) diff --git a/hopper/sm90_pipeline_no_cluster.hpp b/hopper/sm90_pipeline_no_cluster.hpp index 65a3d1554b..1fb805aec1 100644 --- a/hopper/sm90_pipeline_no_cluster.hpp +++ b/hopper/sm90_pipeline_no_cluster.hpp @@ -39,7 +39,7 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 5e13b5f93a..15a7d51364 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -179,3 +179,26 @@ return __VA_ARGS__(); \ } \ }() + +#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...) 
\ + [&] { \ + if (VALUE <= 1) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 1fe43e21fa..0b5a0e2af9 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -5,9 +5,13 @@ import pytest import torch import torch.nn.functional as F +from torch._C import parse_schema from einops import rearrange, repeat -from flash_attn.layers.rotary import apply_rotary_emb +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None from padding import pad_input, unpad_input from test_util import ( @@ -16,7 +20,8 @@ generate_random_padding_mask, ) -from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine, flash_attn_with_kvcache +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" @@ -50,6 +55,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -57,7 +64,7 @@ @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("V_colmajor", [False, True]) @pytest.mark.parametrize("V_colmajor", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -68,7 +75,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -96,12 +103,12 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): - # sink_token_length = 0 if not local else 4 - sink_token_length = 0 if not local else 0 if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed 
torch.random.manual_seed(0) @@ -113,148 +120,170 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # window_size = (-1, -1) if not local else (16, 0) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] - if V_colmajor: - v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( - q, - k, - v, + dv_vals = [d] + if has_qv: + dv_vals = [256, 512] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - sink_token_length=sink_token_length, + attention_chunk=attention_chunk, softcap=softcap, - pack_gqa=pack_gqa, - num_splits=num_splits + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. 
- assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: - g = torch.randn_like(out) - do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) - # import flash_attn_3_cuda - # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( - # g, - # q, - # k, - # v, - # out, - # lse, - # None, - # None, - # None, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # sink_token_length, - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
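# The assertion below is the error-bound pattern used throughout this file: the kernel
# output is compared against an fp32 reference (out_ref) and against the same math
# redone in the working dtype with reordered ops (out_pt, computed with upcast=False,
# reorder_ops=True), and the kernel's error must stay within rtol times that "plain
# PyTorch" error plus an absolute floor that absorbs out_ref's own rounding noise.
# Condensed sketch, with illustrative names for the intermediate values:
#   fwd_atol   = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
#   err_kernel = (out - out_ref).abs().max().item()
#   err_torch  = (out_pt - out_ref).abs().max().item()
#   assert err_kernel <= rtol * err_torch + fwd_atol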
+ assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: - dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + 
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @@ -263,6 +292,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -270,7 +301,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -280,7 +311,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -303,254 +334,295 @@ def test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = nheads + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values 
of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4).detach().requires_grad_() - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] - query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False - ) - key_padding_mask = generate_random_padding_mask( - seqlen_k, batch_size, device, mode="random", zero_lengths=True - ) - - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): - if add_unused: - another_mask = generate_random_padding_mask(max_seq_len, bs, device) - attn_mask = torch.logical_and(padding_mask, another_mask) - unused_mask = torch.logical_xor( - torch.logical_or(padding_mask, another_mask), attn_mask - ) + dv_vals = [d] + if has_qv: + dv_vals = [256, 512] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: - attn_mask = padding_mask - unused_mask = None - return attn_mask, unused_mask - - query_padding_mask, query_unused_mask = _gen_unused_masks( - query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device - ) - key_padding_mask, key_unused_mask = _gen_unused_masks( - key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device - ) - - ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, 
add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + ( q_unpad, k_unpad, v_unpad, + qv_unpad, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, causal=causal, - q_descale=q_descale, - k_descale=k_descale, v_descale=v_descale, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - out = output_pad_fn(out_unpad) - if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. 
- assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: - g_unpad = torch.randn_like(out_unpad) - do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) - # import flash_attn_3_cuda - # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( - # g_unpad, - # q_unpad, - # k_unpad, - # v_unpad, - # out_unpad, - # lse, - # None, - # None, - # None, - # cu_seqlens_q, - # cu_seqlens_k, - # None, None, - # max_seqlen_q, - # max_seqlen_k, - # d ** (-0.5), - # causal, - # window_size[0], window_size[1], - # softcap, - # deterministic, - # 0, # sm_margin - # ) - dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) - dq = dq_pad_fn(dq_unpad) - dk = dk_pad_fn(dk_unpad) - dv = dk_pad_fn(dv_unpad) - if key_unused_mask is not None: - k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") - dk.masked_fill_(k_zero_masking, 0.0) - dv.masked_fill_(k_zero_masking, 0.0) if query_unused_mask is not None: - dq.masked_fill_(q_zero_masking, 0.0) - # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") - # assert (softmax_d - do_o).abs().max().item() <= 1e-5 - # assert dq_accum.abs().max().item() == 0.0 - g = output_pad_fn(g_unpad) + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + # pack_gqa_vals = [False] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1] + # print("cu_seqlens_q: ", cu_seqlens_q) + # print("cu_seqlens_k: ", cu_seqlens_k) + # print("seqused_q: ", seqused_q) + # print("seqused_k: ", seqused_k) + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() - # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) - # P = torch.softmax(qk, -1) - # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) - # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) - # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) - # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) - # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) - dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) - print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") - print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") - print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") - print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") - print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") - print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") - print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") - print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") - print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") - print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") - print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") - print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") - # breakpoint() + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: - dq_atol = 2 * 
(dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol - dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol - dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) - assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("num_splits", [1] + ([0] if not DISABLE_SPLIT else [])) -# @pytest.mark.parametrize("num_splits", [1]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# 
@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if not DISABLE_APPENDKV else [0.0]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) # @pytest.mark.parametrize("page_size", [None]) @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("varlen_q", [False, True]) # @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) @@ -558,6 +630,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -569,8 +642,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): (3, 799), (64, 2048), (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), + # (1, 128 * 1024), + # (16, 128 * 1024), (128, 128), (256, 512), # To test appending KV with more than 1 block (2048, 3577), # Enough tile to test persistent scheduler @@ -587,12 +660,12 @@ def test_flash_attn_kvcache( page_size, rotary_fraction, rotary_interleaved, + has_rotary_seqlens, seqlen_new_eq_seqlen_q, causal, local, new_kv, mha_type, - num_splits, dtype, ): if page_size is not None and seqlen_k % page_size != 0: @@ -601,6 +674,8 @@ def test_flash_attn_kvcache( pytest.skip() if not new_kv and rotary_fraction > 0.0: pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -614,262 +689,314 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input( - output_unpad, indices_q, batch_size, seqlen_q - ) - else: - query_padding_mask = None - q_unpad = q - cu_seqlens_q, max_seqlen_q = None, None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - cu_seqlens_k_new = None - key_new_padding_mask = None - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = 
torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) - v_unpad, *rest = unpad_input(v, key_new_padding_mask) + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: - k_unpad, v_unpad = k, v - else: - k, v, k_unpad, v_unpad = None, None, None, None - if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - page_table = None - else: - ( - k_cache, - v_cache, - page_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, device, dtype_ref - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - if not new_kv: - key_padding_mask = arange < cache_seqlens_expanded - else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new - key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if page_size is None else num_blocks * page_size, - rotary_dim // 2, - device=device, - ) - * 2 - * math.pi - ) - cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, 
cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, - ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved - ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens - ) - k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") - v_to_update = rearrange(v, "b s ... -> (b s) ...") - if varlen_q: - k_to_update = k_to_update[indices_k] - v_to_update = v_to_update[indices_k] - k_cache_ref[update_mask] = k_to_update - v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None - ) - q = q.to(dtype) - q_unpad = q_unpad.to(dtype) if varlen_q else None - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) - k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None - v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None - k = k.to(dtype) if k is not None else None - v = v.to(dtype) if v is not None else None - k_unpad = k_unpad.to(dtype) if k_unpad is not None else None - v_unpad = v_unpad.to(dtype) if v_unpad is not None else None - cos = cos.to(dtype) if cos is not None else None - sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = 
torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) - # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() - if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... 
-> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False, True] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + print(f"{num_splits = }, {precompute_metadata = }") + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, + max_seqlen_q if varlen_q else seqlen_q, + seqlen_k if page_size is None else page_table.shape[1] * page_size, + nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits, + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + 
attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True, + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype - ) + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", @@ -953,7 +1080,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) torch.random.manual_seed(42) - out0, lse0 = flash_attn_func(q, k, v, causal=causal) + out0 = flash_attn_func(q, k, v, causal=causal) g = torch.randn_like(out0) dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq @@ -961,9 +1088,9 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): for i in range(1000): torch.random.manual_seed(42) - out, lse = flash_attn_func(q, k, v, causal=causal) + out = flash_attn_func(q, k, v, causal=causal) assert torch.equal(out, out0) - assert torch.equal(lse, lse0) + # assert torch.equal(lse, lse0) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) @@ -990,12 +1117,12 @@ def attention_combine_ref(out_partial, lse_partial): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024, 2048]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 
112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 155]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [128]) def test_flash_attn_combine(num_splits, seqlen, d, dtype): @@ -1032,3 +1159,43 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype): # # pytorch_profiler(torch.sum, lse_partial) # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) # pytorch_profiler(torch.sum, out_partial) + +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? 
page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/hopper/test_util.py b/hopper/test_util.py index 54eb195eb3..7331ea62ca 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -30,22 +30,23 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False, + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, query_unused_mask=None, key_unused_mask=None, ): """ Arguments: q: (batch_size, seqlen_q, nheads, d) k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked @@ -57,6 +58,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( @@ -67,6 +69,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...") if qv is not None else None if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( @@ -134,6 +137,7 @@ def generate_qkv( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -143,6 +147,7 @@ def generate_qkv( q.detach().requires_grad_(), k.detach().requires_grad_(), v.detach().requires_grad_(), + qv.detach() if qv is not None else None, output_pad_fn, dq_pad_fn, dk_pad_fn, @@ -185,6 +190,39 @@ def construct_local_mask( ) +def construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask=None, + key_padding_mask=None, + key_leftpad=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # Subtract remainder instead of divide and then multiply to take care of negative values + col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk + return torch.logical_or( + col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk + ) + + def attention_ref( q, k, @@ -196,8 +234,10 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size + attention_chunk=0, sink_token_length=0, softcap=0.0, upcast=True, @@ -208,7 +248,8 @@ def attention_ref( Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -221,7 +262,7 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. 
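+        Note: when qv is not None, the logits get an additional qv·v term (scaled like q·k) and the softmax scale becomes 1 / sqrt(head_dim + head_dim_v) instead of 1 / sqrt(head_dim).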
Output: - output: (batch_size, seqlen_q, nheads, head_dim) + output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -229,9 +270,11 @@ def attention_ref( dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b (h g)", g = q.shape[2] // k.shape[2]) - q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype) + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + q = (q.float() * q_descale).to(q.dtype) + qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: @@ -240,14 +283,19 @@ def attention_ref( k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + local_mask = None if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -259,6 +307,18 @@ def attention_ref( key_leftpad=key_leftpad, device=q.device, ) + if attention_chunk > 0: + chunk_mask = construct_chunk_mask( + seqlen_q, + seqlen_k, + attention_chunk, + query_padding_mask, + key_padding_mask, + key_leftpad=key_leftpad, + device=q.device, + ) + local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias @@ -271,7 +331,7 @@ def attention_ref( if key_padding_mask is not None: attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: + if local_mask is not None: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 0b74d0e1f1..3c9e42996b 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -8,6 +8,7 @@ #include "cutlass/arch/barrier.h" #include "named_barrier.hpp" +#include "utils.h" namespace flash { @@ -19,10 +20,15 @@ struct TileSchedulerArguments { int const num_blocks, num_head, num_batch, num_splits; int const qhead_per_khead; int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr - int const seqlen_k, headdim, element_size; // Used to calculate L2 swizzling + int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const 
tile_count_semaphore = nullptr; - int* const cu_seqlens = nullptr; - int* const seqused = nullptr; + int const* const cu_seqlens = nullptr; + int const* const seqused = nullptr; + int const* const num_splits_dynamic_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + // int const* const num_n_blocks_ptr = nullptr; + int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -40,16 +46,20 @@ class SingleTileScheduler { int const qhead_per_khead; int const seqlen; cutlass::FastDivmod nsplits_divmod; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; + int const* const num_splits_dynamic_ptr = nullptr; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { + assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); + assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, args.qhead_per_khead, args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), - !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; + !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused, + args.num_splits_dynamic_ptr}; } static dim3 @@ -61,24 +71,18 @@ class SingleTileScheduler { int block_idx = 0; int bidh = 0; int bidb = 0; - bool is_valid_tile = false; + int split_idx = 0; CUTLASS_DEVICE bool is_valid(Params const& params) const { - return is_valid_tile; + return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { - if constexpr (!Split) { - return {block_idx, bidh, bidb, 0 /*split_idx*/}; - } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - return {block_idx, bidh_actual, bidb, split_idx}; - } + return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; } }; @@ -90,14 +94,27 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { - WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; + if constexpr (Split) { + int split_idx; + work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); + work_info.split_idx = split_idx; + } + bool is_valid_tile = true; if constexpr (Varlen) { int seqlen = params.seqused ? params.seqused[work_info.bidb] : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - work_info.is_valid_tile = work_info.block_idx * kBlock < seqlen; + is_valid_tile = work_info.block_idx * kBlock < seqlen; } + if constexpr (Varlen && Split) { + int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; + is_valid_tile &= work_info.split_idx < num_splits_dynamic; + // Use the top 16 bits to store num_splits + work_info.split_idx |= (num_splits_dynamic << 16); + } + work_info.bidb = is_valid_tile ? 
work_info.bidb : -1; return work_info; } @@ -113,7 +130,7 @@ class SingleTileScheduler { CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { - return {-1, -1, -1, false}; + return {0, 0, -1, 0}; } }; @@ -196,6 +213,8 @@ class StaticPersistentTileScheduler { }; +/////////////////////////////////////////////////////////////////////////////// + template class DynamicPersistentTileScheduler { @@ -232,11 +251,14 @@ class DynamicPersistentTileScheduler { static Params to_underlying_arguments(TileSchedulerArguments const& args) { - int const size_one_kv_head = args.seqlen_k * args.headdim * args.element_size * 2; + long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead - int const swizzle = (1 << cutlass::find_log2(size_l2 / size_one_kv_head)) * (PackGQA ? 1 : args.qhead_per_khead); + // Need to be careful about the case where only one head will fit + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle if a power of 2 + int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; @@ -305,7 +327,7 @@ class DynamicPersistentTileScheduler { void init_consumer() const { if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty } } @@ -324,24 +346,128 @@ class DynamicPersistentTileScheduler { if constexpr (IsProducerWarp) { // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % NumProducerThreads == 0) { *tile_count_smem = current_work.tile_idx; } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return {new_tile_idx}; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int tile_idx = *tile_count_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return {tile_idx}; } } }; +/////////////////////////////////////////////////////////////////////////////// + +class SingleTileBwdLPTScheduler { + +public: + + using SharedStorage = int; + + // 
Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; + cutlass::FastDivmod const l2_minor_residual_divmod; + int const num_hb_quotient; + }; + + static Params + to_underlying_arguments(TileSchedulerArguments const& args) { + // Since it's the bwd pass, seqlen_k gets passed to args.seqlen and seqlen_q is passed to args.seqlen_k + long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); + long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float); + long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head; + int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum + // Swizzle is the size of each "section". Round swizzle to a power of 2 + // Need to be careful about the case where only one head will fit + auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // Seems faster if swizzle is a power of 2 + int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head)); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. + int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; + // printf("num_blocks = %d, num_head = %d, num_batch = %d, size_one_head = %d, ratio = %d, swizzle = %d, num_hb_remainder = %d\n", args.num_blocks, args.num_head, args.num_batch, size_one_head, size_l2 / size_one_head, swizzle, num_hb_remainder); + assert(args.tile_count_semaphore != nullptr); + return {args.num_blocks * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), + cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), + // don't divide by 0 + cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), + (args.num_head * args.num_batch) / swizzle}; + } + + static dim3 + get_grid_shape(Params const& params, int num_sm) { + return {uint32_t(params.total_blocks)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } -template + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int block, bidh, bidb; + int l2_mod, bidhb, bidhb_residual; + bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); + // If we're in the last section (called residual), we don't want to divide by + // swizzle. Instead we want to divide by the remainder. 
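+      // Example: with swizzle = 4, tiles 0..3 of a section map to block 0 for that section's 4 heads, tiles 4..7 map to block 1 for the same heads, and so on, so the Q, dO, and dQaccum of those heads stay resident in L2 while all of their blocks are processed.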
+ if (bidhb < params.num_hb_quotient) { + block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); + } else { + block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); + } + bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); + return {block, bidh, bidb, 0 /*split_idx*/}; + } + + }; + + CUTLASS_DEVICE + SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { } + + template + CUTLASS_DEVICE + WorkTileInfo + get_initial_work(Params const& params) const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {params.total_blocks}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -360,10 +486,17 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + // int const max_kvblocks_in_l2; + cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; - int* const cu_seqlens; - int* const seqused; + int const* const cu_seqlens; + int const* const seqused; + int const* const num_splits_dynamic_ptr; + int const* const num_m_blocks_ptr; + int const* const varlen_batch_idx_ptr; + // int const* const num_n_blocks_ptr; + int const* const num_nheads_in_l2_ptr; }; static Params @@ -371,10 +504,22 @@ class VarlenDynamicPersistentTileScheduler { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. assert(args.tile_count_semaphore != nullptr); - return {args.num_head * (!Split ? 1 : args.num_splits), args.num_batch, + assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx + assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; + // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; + return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + // max_kvblocks_in_l2, + cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 
1 : args.num_splits), - args.tile_count_semaphore, args.cu_seqlens, args.seqused}; + args.tile_count_semaphore, args.cu_seqlens, args.seqused, + args.num_splits_dynamic_ptr, + args.num_m_blocks_ptr, + args.varlen_batch_idx_ptr, + // aras.num_n_blocks_ptr, + args.num_nheads_in_l2_ptr}; } static dim3 @@ -395,12 +540,29 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { + auto get_actual_batch = [&](int virtual_batch) { + if constexpr(Prepared && Sort) { + return params.varlen_batch_idx_ptr[virtual_batch]; + } else { + return virtual_batch; + } + }; if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; + return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { - int split_idx; - int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); - return {block, bidh_actual, bidb, split_idx}; + // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh); + uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; + int bidh_actual = reinterpret_cast(bidh_actual_u); + // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx + uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); + int split_idx = reinterpret_cast(split_idx_u); + // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); + // if (threadIdx.x == 128) { + // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); + // } + return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; @@ -411,43 +573,65 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE WorkTileInfo tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { - auto prefix_sum = [](int val) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { - int32_t partial_sum = __shfl_up_sync(0xffffffff, val, i); - if (lane >= i) { val += partial_sum; } + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + auto get_num_m_blocks = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + if constexpr (Prepared) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; + } else { + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlockM) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlockM) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; } - return val; }; - auto get_num_m_blocks = [&](int bidb_start) { - int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - int seqlen; - if (params.seqused) { - seqlen = lane + bidb_start < params.num_batch ? 
params.seqused[lane + bidb_start] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = lane + bidb_start <= params.num_batch ? params.cu_seqlens[lane + bidb_start] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; + auto get_num_splits = [&] (int bidb_start) { + int batch_idx = lane + bidb_start; + bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; + if constexpr (!Split) { + return is_valid ? 1 : 0; + } else if constexpr(Prepared) { + return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; } else { - seqlen = params.seqlen; + return is_valid ? params.nsplits_divmod.divisor : 0; } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } - return lane + bidb_start < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane + int num_splits = get_num_splits(current_work.bidb); + int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; // Cumulative number of blocks for the next 31 batches - int num_m_blocks_cumulative = prefix_sum(num_m_blocks); + int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); - int group_end_tile = current_work.tile_idx - current_work.block - current_work.bidh * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // Only the lower 16 bits are the actual bidh + // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // if constexpr (Split) { + // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + // } + // NEW: current_work.tile_idx holds group_start_tile for starting batch + int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); + // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); // } + // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); } while (group_end_tile <= next_tile_idx) { bidb += cutlass::NumThreadsPerWarp - 1; if (bidb >= params.num_batch) { @@ -457,7 +641,9 @@ class VarlenDynamicPersistentTileScheduler { return 
{next_tile_idx, 0, 0, params.num_batch}; } num_m_blocks = get_num_m_blocks(bidb); - num_m_blocks_cumulative = prefix_sum(num_m_blocks); + num_splits = get_num_splits(bidb); + num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; + num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); group_end_tile += m_blocks_in_group * params.num_head; // if (blockIdx.x <= 9 && threadIdx.x == 0) { @@ -468,15 +654,85 @@ class VarlenDynamicPersistentTileScheduler { // The next problem to process is the first one that does not have ending tile position // that is greater than or equal to tile index. int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); + // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("blockIdx.x = %d, threadIdx.x = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; + if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } + group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int mh_block = next_tile_idx - group_start_tile; + int block, bidh; + if constexpr (LPT) { + if (!Split || num_splits == 1) { + // NOTE: code for computing nheads_in_l2 directly left as reference + // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; + // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks + // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); + // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } + // nheads_in_l2 = min(nheads_in_l2, params.num_head); + auto get_nheads_in_l2 = [&](int batch_idx) { + if constexpr(Prepared) { + return params.num_nheads_in_l2_ptr[batch_idx]; + } else { + return !PackGQA ? params.qhead_per_khead : 1; + } + }; + int nheads_in_l2 = get_nheads_in_l2(bidb); + int mh_in_l2 = nheads_in_l2 * num_m_blocks; + int section_idx = mh_block / mh_in_l2; + int l2_mod = mh_block - section_idx * mh_in_l2; + // tail section + int nheads_remainder = params.num_head - section_idx * nheads_in_l2; + int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? 
nheads_in_l2 : nheads_remainder; + block = l2_mod / nheads_in_this_section; + int bidh_residual = l2_mod - block * nheads_in_this_section; + bidh = section_idx * nheads_in_l2 + bidh_residual; + if constexpr(Split) { + // remember to set num_splits = 1 in work tile + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } else { + // NOTE: leave traverse heads first version for reference + // block = params.head_divmod.divmod(bidh, mh_block); + // if constexpr (Split) { + // int split_idx = block / num_m_blocks; + // block = block - split_idx * num_m_blocks; + // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // bidh = reinterpret_cast(bidh_packed); + // } + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } + block = num_m_blocks - 1 - block; + } else { + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); + // } + } + return {group_start_tile, block, bidh, bidb}; } template @@ -488,7 +744,7 @@ class VarlenDynamicPersistentTileScheduler { if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { return get_next_work(params, {0, 0, 0, 0}); @@ -518,20 +774,22 @@ class VarlenDynamicPersistentTileScheduler { int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); WorkTileInfo work_info = 
{__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { - flash::named_barrier_sync(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int4 work_info = *work_info_smem; - flash::named_barrier_arrive(NumThreads, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; } } }; +/////////////////////////////////////////////////////////////////////////////// + } // flash diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 127f518bbb..8353542c47 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -6,23 +6,35 @@ #include -// Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( - int headdim, bool is_causal, bool is_local, int element_size=2, - bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, + bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - return {192, 128, true, true}; + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + if (headdim_v == 512) { + return {64, 64, false, false}; + } else if (headdim_v == 256) { + return {128, 96, true, false}; + } else { + // Switch to tile size 192 x 192 for now + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { - return {192, is_local || paged_kv ? 128 : 144, false, true}; + return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; - // {128, 192, false, false} and {192, 128, false, true} are quite good too - // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 
128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : 112, true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -32,18 +44,18 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, 128, true, true}; } else if (headdim <= 128) { - return {128, paged_kv ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; + return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true}; } else if (headdim <= 192) { - return {128, (paged_kv || softcap) && is_local ? 128 : 160, true, true}; + return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true}; } else { - return {128, is_local ? 64 : 128, true, !paged_kv}; // PagedKV uses more registers so we disabled IntraWGOverlap + return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA}; // PagedKV uses more registers so we disabled IntraWGOverlap } } } // Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} constexpr std::tuple tile_size_fwd_sm8x( - bool sm86_or_89, int headdim, bool is_causal, bool is_local, int element_size=2, + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool paged_kv=false, bool varlen_and_split=false, bool softcap=false, bool append_kv=false) { if (element_size == 2) { diff --git a/hopper/utils.h b/hopper/utils.h index fa8938c853..a568e3075a 100644 --- a/hopper/utils.h +++ b/hopper/utils.h @@ -21,17 +21,7 @@ #include #include - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while(0) - -#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) +#include "cuda_check.h" namespace flash { @@ -109,6 +99,25 @@ static __device__ __forceinline__ T run(T x, Operator &op) { //////////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_HOST_DEVICE +int div_floor(cutlass::FastDivmod const& divmod, int dividend) { + // Take care of the negative case: https://stackoverflow.com/questions/39304681/division-with-negative-dividend-but-rounded-towards-negative-infinity + // Maybe the compiler will turn the -1 - * into bit negation operation, I haven't checked. + return dividend >= 0 ? 
divmod.divide(dividend) : -1 - divmod.divide(-1 - dividend); +} + +CUTLASS_HOST_DEVICE +int round_down(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend) * divmod.divisor; +} + +CUTLASS_HOST_DEVICE +int round_up(cutlass::FastDivmod const& divmod, int dividend) { + return div_floor(divmod, dividend - 1) * divmod.divisor + divmod.divisor; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) template @@ -272,9 +281,11 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } + static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); + static constexpr int kMaxKIters = 16; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { if constexpr (!SwapAB) { cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } else { @@ -282,6 +293,22 @@ CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } + // In the case of large kNumKIters, the compiler chooses to store the smem addresses + // in registers, causing spills. This loop forces the compiler to recompute the addresses. + if constexpr (kNumKIters > kMaxKIters) { + // This will always be zero, just a way to force the compiler to recompute the smem + // addresses. This results in USEL instructions. There's probably a better way to do this. + int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 
0 : 1; + CUTLASS_PRAGMA_UNROLL + for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { + if constexpr (!SwapAB) { + cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); + } else { + cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } warpgroup_commit_batch(); if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(tCrC); @@ -354,6 +381,69 @@ CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Ten } } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + static constexpr int rA = decltype(rank(tA))::value; + static constexpr int rB = decltype(rank(tB))::value; + static constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -562,6 +652,26 @@ CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTE_DEVICE T warp_prefix_sum(T val) { + int lane = threadIdx.x % cutlass::NumThreadsPerWarp; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { + T partial_sum = __shfl_up_sync(0xffffffff, val, i); + if (lane >= i) { val += partial_sum; } + } + return val; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + CUTLASS_DEVICE int canonical_warp_group_idx_nosync() { return threadIdx.x / cutlass::NumThreadsPerWarpGroup; diff --git a/setup.py b/setup.py index a802a7e65e..a108c412c0 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" - +SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False @functools.lru_cache(maxsize=None) def cuda_archs() -> str: @@ -121,7 +121,7 @@ def check_if_rocm_home_none(global_option: str) -> None: def append_nvcc_threads(nvcc_extra_args): - nvcc_threads = os.getenv("NVCC_THREADS") or "2" + nvcc_threads = 
os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] @@ -132,7 +132,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] # Validate if each element in archs is in allowed_archs assert all( @@ -145,11 +145,20 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -if IS_ROCM: - if not USE_TRITON_ROCM: - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) +if os.path.isdir(".git"): + if not SKIP_CK_BUILD: + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) else: - subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + if IS_ROCM: + if not SKIP_CK_BUILD: + assert ( + os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") + ), "csrc/composable_kernel is missing, please use source distribution or git clone" + else: + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" if not SKIP_CUDA_BUILD and not IS_ROCM: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) @@ -186,6 +195,34 @@ def validate_and_update_archs(archs): # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True + + nvcc_flags = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", + # "--ptxas-options=-O2", + # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", + ] + + compiler_c17_flag=["-O3", "-std=c++17"] + # Add Windows-specific flags + if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1': + nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"]) + compiler_c17_flag=["-O2", "/std:c++17", "/Zc:__cplusplus"] + ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", @@ -199,8 +236,6 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", @@ -213,8 +248,6 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu", - 
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu", @@ -227,8 +260,6 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", @@ -241,8 +272,6 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu", @@ -255,8 +284,6 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", @@ -269,38 +296,14 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - # "--ptxas-options=-v", - # "--ptxas-options=-O2", - # "-lineinfo", - # "-DFLASHATTENTION_DISABLE_BACKWARD", - # "-DFLASHATTENTION_DISABLE_DROPOUT", - # "-DFLASHATTENTION_DISABLE_ALIBI", - # "-DFLASHATTENTION_DISABLE_SOFTCAP", - # "-DFLASHATTENTION_DISABLE_UNEVEN_K", - # "-DFLASHATTENTION_DISABLE_LOCAL", - ] - + cc_flag - ), + "cxx": compiler_c17_flag, + "nvcc": append_nvcc_threads(nvcc_flags + cc_flag), }, include_dirs=[ Path(this_dir) / "csrc" / "flash_attn", @@ -314,20 +317,19 @@ def validate_and_update_archs(archs): TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) - if USE_TRITON_ROCM: - # Skip C++ extension compilation if using Triton Backend - pass - 
else: + # Skips CK C++ extension compilation if using Triton Backend + if not SKIP_CK_BUILD: ck_dir = "csrc/composable_kernel" #use codegen get code dispatch if not os.path.exists("./build"): os.makedirs("build") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") + optdim = os.getenv("OPT_DIM", "32,64,128,256") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 @@ -340,7 +342,11 @@ def validate_and_update_archs(archs): archs = os.getenv("GPU_ARCHS", "native").split(";") validate_and_update_archs(archs) - cc_flag = [f"--offload-arch={arch}" for arch in archs] + if archs != ['native']: + cc_flag = [f"--offload-arch={arch}" for arch in archs] + else: + arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0] + cc_flag = [f"--offload-arch={arch}"] # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI @@ -387,6 +393,8 @@ def validate_and_update_archs(archs): # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 hip_version = get_hip_version() + if hip_version > Version('5.5.00000'): + cc_flag += ["-mllvm", "--lsr-drop-solution=1"] if hip_version > Version('5.7.23302'): cc_flag += ["-fno-offload-uniform-block"] if hip_version > Version('6.1.40090'): diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py new file mode 100644 index 0000000000..f3042f0763 --- /dev/null +++ b/tests/cute/test_flash_attn.py @@ -0,0 +1,1038 @@ +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
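+# Tests for the flash_attn.cute interface (flash_attn_func, flash_attn_varlen_func, flash_attn_combine).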
+ +import math +import itertools + +import pytest +import torch + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.utils.testing import attention_ref, generate_qkv, generate_random_padding_mask +from flash_attn.cute.interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [128, 192]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype +): + if (causal or local) and seqlen_k < seqlen_q: + pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() + # # if qv is not None: + # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # num_splits_vals = [1, 3] + pack_gqa_vals = [False, True, None] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + # qv=qv, + # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + softcap=softcap, + learnable_sink=learnable_sink, + # pack_gqa=pack_gqa, + # num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch 
implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and softcap == 0.0 + and not local + and dv == d + and learnable_sink is None + and False + ): + g = torch.randn_like(out) + # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mqa"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("softcap", [0.0, 15.0]) +@pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) 
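Both the dense test above and the varlen test below draw a random (left, right) window_size when local=True, with (None, None) meaning no window. As a reminder of the windowing convention these tests appear to assume (bottom-right alignment when seqlen_q != seqlen_k, causal being the special case of right=0 with an unbounded left), here is an illustrative keep-mask builder; it is a sketch for reading the tests, not the mask construction used inside attention_ref.

import torch

def sliding_window_keep_mask(seqlen_q, seqlen_k, left, right, device="cpu"):
    # True where a query may attend to a key. Assumed convention: the last query
    # is aligned with the last key, so query i may see keys j with
    #   i + (seqlen_k - seqlen_q) - left <= j <= i + (seqlen_k - seqlen_q) + right,
    # and None on either side means unbounded on that side.
    row = torch.arange(seqlen_q, device=device).unsqueeze(-1)
    col = torch.arange(seqlen_k, device=device)
    shift = seqlen_k - seqlen_q
    keep = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=device)
    if right is not None:
        keep &= col <= row + shift + right
    if left is not None:
        keep &= col >= row + shift - left
    return keep

# Causal attention corresponds to right=0 with an unbounded left window.
print(sliding_window_keep_mask(2, 4, left=1, right=0))
# tensor([[False,  True,  True, False],
#         [False, False,  True,  True]])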
+@pytest.mark.parametrize("add_unused_qkv", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", [128, 192]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + # (1, 1), + # (1, 3), + # (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype +): + if (causal or local): # Right now we only support causal attention with seqlen_k == seqlen_q + seqlen_k = seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + batch_size = 49 if seqlen_q <= 1024 else 7 + nheads = 6 + # batch_size = 1 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. 
+ q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + # TODO: test zero_lengths + key_padding_mask = generate_random_padding_mask( + # seqlen_k, batch_size, device, mode="random", zero_lengths=True + seqlen_k, batch_size, device, mode="random", zero_lengths=False + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + # query_padding_mask[:] = True + # query_unused_mask = None + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + if causal or local: + key_padding_mask = query_padding_mask + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max 
diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True, None] + # num_splits_vals = [1, 3] + num_splits_vals = [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + # max_seqlen_k, + # seqused_q=seqused_q, + # seqused_k=seqused_k, + causal=causal, + # qv=qv_unpad, + # q_descale=q_descale, + # k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + # attention_chunk=attention_chunk, + learnable_sink=learnable_sink, + softcap=softcap, + pack_gqa=pack_gqa, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + and dv == d + and not has_learnable_sink + and False + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = 
torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +@pytest.mark.parametrize("has_learnable_sink", [False, True]) +# @pytest.mark.parametrize("has_learnable_sink", [False]) +# @pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("new_kv", [False]) +@pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("rotary_interleaved", [False, True]) +@pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("rotary_fraction", [0.0]) +# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) +@pytest.mark.parametrize("page_size", [None, 128]) +# @pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("has_leftpad", [False, True]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("varlen_q", [False, True]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) 
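The varlen test above feeds the kernel a packed layout: all sequences of the batch are concatenated into one (total_tokens, nheads, headdim) tensor and delimited by cu_seqlens, which generate_qkv/unpad_input derive from the padding masks. A toy illustration of that convention with made-up lengths; only the shapes and the cumulative-length bookkeeping are shown, independent of the real helpers.

import torch

# Three sequences with valid lengths 3, 1 and 4.
seqlens = torch.tensor([3, 1, 4], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                     # tensor([0, 3, 4, 8], dtype=torch.int32)
max_seqlen = int(seqlens.max())       # 4
total_tokens = int(cu_seqlens[-1])    # 8

# Packed ("unpadded") tensor: sequence b occupies rows cu_seqlens[b]:cu_seqlens[b+1].
nheads, headdim = 2, 8
q_unpad = torch.randn(total_tokens, nheads, headdim)
q_seq1 = q_unpad[cu_seqlens[1]:cu_seqlens[2]]   # the single valid token of sequence 1
assert q_seq1.shape == (1, nheads, headdim)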
+@pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # # (1, 128 * 1024), + # # (16, 128 * 1024), + # (128, 128), + # (256, 512), # To test appending KV with more than 1 block + # (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + has_learnable_sink, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + dv_vals = [d] + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0] + attention_chunk_vals = [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + # has_qv = d == 64 and dv >= 256 + has_qv = False + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q) + qv_unpad = rearrange(qv, "b s ... 
-> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + if has_learnable_sink: + learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) + else: + learnable_sink = None + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, 
seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + learnable_sink=learnable_sink, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + # num_splits_vals = [1, 0] + num_splits_vals = [1] + # precompute_metadata_vals = [False, True] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + # if precompute_metadata: + # scheduler_metadata = get_scheduler_metadata( + # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + # cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + # max_seqlen_k_new=seqlen_new, page_size=page_size, + # causal=causal, window_size=window_size, attention_chunk=attention_chunk, + # num_splits=num_splits + # ) + # else: + # scheduler_metadata = None + scheduler_metadata = None + # Repeat to test 
metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + # out, lse, *rest = flash_attn_with_kvcache( + out, lse, *rest = flash_attn_varlen_func( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + # k if not new_kv or not varlen_q else k_unpad, + # v if not new_kv or not varlen_q else v_unpad, + # qv=qv if not varlen_q else qv_unpad, + # rotary_cos=cos, + # rotary_sin=sin, + seqused_k=cache_seqlens, + # cache_batch_idx=cache_batch_idx, + # cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + # cu_seqlens_k_new=cu_seqlens_k_new, + # rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + learnable_sink=learnable_sink, + # attention_chunk=attention_chunk, + # rotary_interleaved=rotary_interleaved, + # scheduler_metadata=scheduler_metadata, + # num_splits=num_splits, + # return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, seqlen, nheads) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [11]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + + # Test with LSE returned (default behavior) + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # Test with LSE not returned + out_no_lse, lse_no_lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=False) + assert lse_no_lse is None, "LSE should be None when return_lse=False" + assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), "Output should be the same regardless of return_lse" diff --git a/tests/ops/triton/test_layer_norm.py b/tests/ops/triton/test_layer_norm.py index 3d92b6b329..2400132764 100644 --- a/tests/ops/triton/test_layer_norm.py +++ b/tests/ops/triton/test_layer_norm.py @@ -16,8 +16,10 @@ is_sm8x = 
torch.cuda.get_device_capability("cuda")[0] >= 8 +# @pytest.mark.parametrize("zero_centered_weight", [False, True]) +@pytest.mark.parametrize("zero_centered_weight", [False]) @pytest.mark.parametrize("has_weight1", [False, True]) -# @pytest.mark.parametrize("has_weight1", [True]) +# @pytest.mark.parametrize("has_weight1", [False]) @pytest.mark.parametrize("has_x1", [False, True]) # @pytest.mark.parametrize("has_x1", [False]) @pytest.mark.parametrize("has_rowscale", [False, True]) @@ -25,11 +27,11 @@ @pytest.mark.parametrize("dropout_p", [0.0, 0.27]) # @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("prenorm", [True, False]) -# @pytest.mark.parametrize("prenorm", [False]) +# @pytest.mark.parametrize("prenorm", [True]) @pytest.mark.parametrize("is_rms_norm", [False, True]) # @pytest.mark.parametrize("is_rms_norm", [True]) @pytest.mark.parametrize("has_residual", [True, False]) -# @pytest.mark.parametrize("has_residual", [False]) +# @pytest.mark.parametrize("has_residual", [True]) @pytest.mark.parametrize( "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else []) ) @@ -41,7 +43,7 @@ ) # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)]) @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096]) -# @pytest.mark.parametrize("hidden_size", [256]) +# @pytest.mark.parametrize("hidden_size", [1024]) def test_layer_norm( hidden_size, input_dtype, @@ -54,6 +56,7 @@ def test_layer_norm( has_rowscale, has_x1, has_weight1, + zero_centered_weight, ): if has_rowscale and has_x1: pytest.skip("Not supported") @@ -145,6 +148,7 @@ def test_layer_norm( rowscale=rowscale, prenorm=prenorm, residual_in_fp32=residual_in_fp32, + zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=True, ) @@ -162,6 +166,7 @@ def test_layer_norm( dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, + zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, ) @@ -177,6 +182,7 @@ def test_layer_norm( dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, + zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, upcast=True, diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index 503b7bf01c..d5590fcfc8 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -1399,7 +1399,7 @@ def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() - assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + assert (q.grad - q_ref.grad).abs().max().item() <= 7 * ( q_pt.grad - q_ref.grad ).abs().max().item() + 1e-3 assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py old mode 100644 new mode 100755 index d64246f950..b5e026803c --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1,6 +1,4 @@ import math -import os -import random import pytest import torch @@ -18,12 +16,7 @@ from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb -from flash_attn.flash_attn_triton_amd.utils import DEBUG - -# Test ROCM Triton 
Backend -USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" -if USE_TRITON_ROCM: - random.seed(42) +from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_rdna MAX_HEADDIM_SM8x = 192 @@ -572,33 +565,26 @@ def get_dropout_fraction( return dropped.sum() / valid.sum() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) -# @pytest.mark.parametrize("d", [32]) +# @pytest.mark.parametrize("d", [64]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize("seqlen", [128]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize("dropout_p", [0.0]) +# @pytest.mark.parametrize("seqlen", [512]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): - if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -719,45 +705,35 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if DEBUG: - print("dqkv:", dqkv, dqkv.shape) - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_pt:", dqkv_pt, dqkv_pt.shape) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize('dtype', [torch.float16]) -# @pytest.mark.parametrize("deterministic", 
[False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [32]) +# @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) # @pytest.mark.parametrize('seqlen', [128]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_qkvpacked( seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype ): - if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM device = "cuda" @@ -877,7 +853,7 @@ def test_flash_attn_varlen_qkvpacked( assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) @@ -886,23 +862,20 @@ def test_flash_attn_varlen_qkvpacked( assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() -# @pytest.mark.parametrize("kvpacked", [True, False]) @pytest.mark.parametrize("kvpacked", [False]) -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("kvpacked", [False]) +@pytest.mark.parametrize("dtype", ([torch.float16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# 
@pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) -@pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -925,22 +898,16 @@ def test_flash_attn_varlen_qkvpacked( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize("dropout_p", [0.0]) -# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("softcap", [0.0]) def test_flash_attn_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if USE_TRITON_ROCM: - if dropout_p != 0.0: - pytest.skip("Dropout not supported on AMD's Triton Backend yet") - - if softcap != 0.0: - pytest.skip("softcap not supported on AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") + if causal: + if seqlen_q ==1024 and seqlen_k==1024 and d==160: + pytest.skip("This test with causal=True is flakey") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1002,10 +969,6 @@ def test_flash_attn_output( deterministic=deterministic, return_attn_probs=True, ) - if DEBUG: - print("out:", out, out.shape) - print("lse:", lse, lse.shape) - if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask, @@ -1160,55 +1123,37 @@ def test_flash_attn_output( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
- if DEBUG: - print("out:", out, out.shape) - print("out_ref:", out_ref, out_ref.shape) assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) - assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() - + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @pytest.mark.parametrize("kvpacked", [False]) # @pytest.mark.parametrize('kvpacked', [False]) -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize('dtype', [torch.float16]) -# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -@pytest.mark.parametrize('mha_type', ["mha"]) -# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize('mha_type', ["mqa"]) @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("alibi", [False, True]) -@pytest.mark.parametrize("alibi", [False]) -# @pytest.mark.parametrize("local", [False, True]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [True]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) -# @pytest.mark.parametrize('d', [160]) +# @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1226,23 +1171,15 @@ def test_flash_attn_output( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) -# @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -@pytest.mark.parametrize('dropout_p', [0.0]) -# @pytest.mark.parametrize("softcap", [0.0, 50.0]) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("softcap", [0.0]) +# @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if USE_TRITON_ROCM: - if 
dropout_p != 0.0: - pytest.skip("Dropout not supported in AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - if softcap != 0.0: - pytest.skip("softcap not supported on AMD's Triton Backend yet") - + if seqlen_q == 1 and seqlen_k == 147 and kvpacked == True and dropout_p != 0.0: + pytest.skip("This config with dropout is flaky on AMD.") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1347,11 +1284,6 @@ def test_flash_attn_varlen_output( deterministic=deterministic, return_attn_probs=True, ) - if DEBUG: - print("out_unpad:", out_unpad, out_unpad.shape) - print("sm_lse:", sm_lse, sm_lse.shape) - - out = output_pad_fn(out_unpad) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( @@ -1516,44 +1448,29 @@ def test_flash_attn_varlen_output( assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() if dropout_p > 0.0: - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) - assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128]) -# @pytest.mark.parametrize("d", [32]) -# @pytest.mark.parametrize("swap_sq_sk", [False, True]) @pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1571,6 +1488,10 @@ def test_flash_attn_varlen_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): + if USE_TRITON_ROCM: + if is_rdna(): + if seqlen_q == 1 and seqlen_k == 239 and d == 
256: + pytest.skip("This config does not work on RDNA devices.") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1646,36 +1567,23 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 - if DEBUG: - print("dv:", dv, dv.shape) - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_pt:", dv_pt, dv_pt.shape) - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 - - if DEBUG: - print("dk:", dk, dk.shape) - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_pt:", dk_pt, dk_pt.shape) + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 - if DEBUG: - print("dq:", dq, dq.shape) - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_pt:", dq_pt, dq_pt.shape) - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -@pytest.mark.parametrize("dtype", [torch.float16]) -# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) -# @pytest.mark.parametrize("swap_sq_sk", [False, True]) @pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -1692,7 +1600,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ], ) # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged -# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) @pytest.mark.parametrize("paged_kv_block_size", [None]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) def test_flash_attn_varlen_causal( @@ -1834,6 +1741,136 @@ def test_flash_attn_varlen_causal( assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("deterministic", [True]) +@pytest.mark.parametrize("alibi", [False, True]) +# @pytest.mark.parametrize("alibi", [True]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 
192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (3, 1024), + (1, 339), + (64, 800), + (3, 799), + (64, 2048), + (16, 20000), + (16, 100000), + (128, 128), + (256, 256), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.skip() +def test_flash_attn_splitkv( + seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype +): + if USE_TRITON_ROCM: + if seqlen_q == 1 and seqlen_k == 339 and swap_sq_sk == True: + pytest.skip("This config is flaky on AMD.") + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 1 + nheads = 12 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + if alibi: + alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 + attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) + else: + alibi_slopes, attn_bias = None, None + out, lse, _ = flash_attn_func( + q, + k, + v, + 0.0, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=True, + ) + out_ref, attn_ref = attention_ref( + q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + attn_bias, + 0.0, + None, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
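+ # The gradient checks below use a looser bound when alibi is enabled: the allowed error is 8x (instead of 2x) the Pytorch implementation's own error, plus a small 2e-4 absolute slack, presumably because the bias term adds extra floating-point error.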
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + mult = 2 if not alibi else 8 + assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 + assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 + assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 + + # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("num_splits", [1, 0]) @@ -1850,15 +1887,15 @@ def test_flash_attn_varlen_causal( # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -# @pytest.mark.parametrize("rotary_interleaved", [False, True]) -@pytest.mark.parametrize("rotary_interleaved", [False]) -# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) -@pytest.mark.parametrize("rotary_fraction", [0.0]) -# @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) -# @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) +@pytest.mark.parametrize("rotary_interleaved", [False, True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("paged_kv_block_size", [None]) -# @pytest.mark.parametrize("has_leftpad", [False, True]) +# @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) +# @pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_leftpad", [True]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) @@ -1901,18 +1938,6 @@ def test_flash_attn_kvcache( num_splits, dtype, ): - if USE_TRITON_ROCM: - if paged_kv_block_size is not None: - pytest.skip("paged attention not supported on AMD's Triton Backend yet") - - if local == True: - pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet") - - if rotary_interleaved == True or rotary_fraction > 0.0: - pytest.skip("rotary embedding not supported on AMD's Triton Backend yet") - - if has_leftpad == True: - pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet") if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: @@ -2157,3 +2182,366 @@ def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, )[:, :seqlen_k] return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + 
], +) +@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) +@pytest.mark.skip() +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + g = torch.randn_like(out0) + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + ( + dq0, + dk0, + dv0, + ) = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(250): + torch.random.manual_seed(42) + out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) + assert torch.equal(out, out0) + assert torch.equal(lse, lse0) + + if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [16, 32, 64]) +# @pytest.mark.parametrize('d', [16]) +@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128]) +# @pytest.mark.parametrize('seqlen', [2]) +@pytest.mark.skip() +def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): + """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, + in the case where seqlen % 128 != 0. 
+ """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 2 + nheads = 5 + q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5 + k, v = [ + torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 + for _ in range(2) + ] + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + out = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out) + out.backward(g) + q_pt = q.detach().clone().requires_grad_(True) + k_pt = k.detach().clone().requires_grad_(True) + v_pt = v.detach().clone().requires_grad_(True) + out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + out_pt.backward(g) + q_ref = q.detach().clone().requires_grad_(True) + k_ref = k.detach().clone().requires_grad_(True) + v_ref = v.detach().clone().requires_grad_(True) + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + out_ref.backward(g) + print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") + print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") + print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( + q_pt.grad - q_ref.grad + ).abs().max().item() + 1e-3 + assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( + k_pt.grad - k_ref.grad + ).abs().max().item() + 1e-3 + assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( + v_pt.grad - v_ref.grad + ).abs().max().item() + 1e-3 + + +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [64, 128]) +# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) +# @pytest.mark.parametrize('seqlen', [128]) +@pytest.mark.skip() +def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): + """We previously had a bug where we were using the wrong strides of dout, which shows up + when dout is not contiguous. + """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + nheads = 2 + q, k, v = [ + torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) + for _ in range(3) + ] + out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") + # So g is not contiguous + g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] + out.backward(g) + q_pt = q.detach().clone().requires_grad_(True) + k_pt = k.detach().clone().requires_grad_(True) + v_pt = v.detach().clone().requires_grad_(True) + out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) + out_pt = rearrange(out_pt, "b s ... -> s b ...") + out_pt.backward(g) + q_ref = q.detach().clone().requires_grad_(True) + k_ref = k.detach().clone().requires_grad_(True) + v_ref = v.detach().clone().requires_grad_(True) + out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) + out_ref = rearrange(out_ref, "b s ... 
-> s b ...") + out_ref.backward(g) + print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") + print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") + print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + assert (q.grad - q_ref.grad).abs().max().item() <= 2 * ( + q_pt.grad - q_ref.grad + ).abs().max().item() + assert (k.grad - k_ref.grad).abs().max().item() <= 2 * ( + k_pt.grad - k_ref.grad + ).abs().max().item() + assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( + v_pt.grad - v_ref.grad + ).abs().max().item() + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [16, 32, 64]) +# @pytest.mark.parametrize('d', [16]) +@pytest.mark.skip() +def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): + """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, + in the case where seqlen % 128 != 0 or varlen. + """ + device = "cuda" + # set seed + torch.random.manual_seed(0) + nheads = 5 + q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32) + k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32) + Mq = 256 + Mk = 3 + + q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3 + k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)] + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) + g = torch.randn_like(out) + out.backward(g) + + assert not q.grad.isnan().any() + assert not k.grad.isnan().any() + assert not v.grad.isnan().any() + + +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [False]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.skip() +def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 4 + nheads = 9 + window_size 
= (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) + + g = torch.randn_like(out) + dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) + for _ in range(50): + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, dq0) + + +@pytest.mark.parametrize("dtype", ([torch.float16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("swap_sq_sk", [False]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +@pytest.mark.skip() +def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 2 + nheads = 9 + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + window_size=window_size, + deterministic=True, + ) + + g = torch.randn_like(out) + dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + for _ in range(50): + dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, 
dq0) diff --git a/tests/test_rotary.py b/tests/test_rotary.py index 0676d329c6..0b44744cb4 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -5,6 +5,9 @@ import torch import torch.nn.functional as F from einops import rearrange + +import triton + from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_ from flash_attn.bert_padding import pad_input, unpad_input @@ -45,7 +48,7 @@ def index_cos_sin(cos, sin, seqlen_offsets, seqlen): @pytest.mark.parametrize( "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) ) -# @pytest.mark.parametrize('dtype', ([torch.float16])) +# @pytest.mark.parametrize('dtype', ([torch.bfloat16])) @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) # @pytest.mark.parametrize("seqlen_offsets_type", [0]) @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5]) @@ -97,6 +100,8 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]) ) # @pytest.mark.parametrize('dtype', ([torch.float16])) +@pytest.mark.parametrize("compiled", [False, True]) +# @pytest.mark.parametrize("compiled", [True]) @pytest.mark.parametrize("gqa", [False, True]) # @pytest.mark.parametrize("gqa", [False]) @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor]) @@ -105,7 +110,9 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t # @pytest.mark.parametrize('rotary_fraction', [1.0]) @pytest.mark.parametrize("interleaved", [False, True]) # @pytest.mark.parametrize('interleaved', [False]) -def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype): +def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, compiled, dtype): + if compiled: # Don't fall back to eager just bc of recompilation + torch._dynamo.config.recompile_limit = 2 ** 31 rtol = 1e-3 batch_size = 32 nheads = 4 @@ -126,7 +133,8 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, qkv_pt = qkv.detach().clone().requires_grad_() cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device) - out = apply_rotary_emb_qkv_( + fn = apply_rotary_emb_qkv_ if not compiled else torch.compile(apply_rotary_emb_qkv_) + out = fn( qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, num_heads_q=None if not gqa else nheads ) @@ -271,7 +279,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of def test_compilation_count(): - batch_size = 1 + nheads = 4 headdim = 128 device = "cuda" dtype = torch.float16 @@ -288,11 +296,17 @@ def count_compilations(*args, **kwargs): old_cache_func = JITFunction.cache_hook try: - rotary_kernel.cache.clear() + if hasattr(rotary_kernel, "cache"): + rotary_kernel.cache.clear() + else: # Triton 3.3 replaces cache with per-device device_caches + device = triton.runtime.driver.active.get_current_device() + # device_caches[device] returns a 4-tuple: (kernel_cache, target, backend, binder) + rotary_kernel.device_caches[device][0].clear() + JITFunction.cache_hook = count_compilations for seqlen in (128, 256): - for nheads in (4, 32): + for batch_size in (4, 32): x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device) x.requires_grad_() cos, sin = generate_cos_sin(seqlen, headdim, 
device, dtype) diff --git a/usage.md b/usage.md index 133bfbdb6b..6cd2365241 100644 --- a/usage.md +++ b/usage.md @@ -1,8 +1,7 @@ # FlashAttention adoption We've been very happy to see FlashAttention being adopted by many organizations -and research labs to speed up their training / inference (within 6 months after -FlashAttention's release, at the time of writing). +and research labs to speed up their training / inference. This page contains a partial list of places where FlashAttention is being used. If you'd like to add links to your organization / product / codebase, please open a PR or email us. We'd very much like to hear from you!